From bb1ebb3c6b6aff8a5e5b0d20198f4268fe91f817 Mon Sep 17 00:00:00 2001 From: Jiayi Ni Date: Wed, 22 Oct 2025 12:02:28 -0700 Subject: [PATCH] feat: Add rerank models and rerank API change (#3831) # What does this PR do? - Extend the model type to include rerank models. - Implement `rerank()` method in inference router. - Add `rerank_model_list` to `OpenAIMixin` to enable providers to register and identify rerank models - Update documentation. ## Test Plan ``` pytest tests/unit/providers/utils/inference/test_openai_mixin.py ``` --- docs/docs/providers/inference/index.mdx | 8 +- docs/static/deprecated-llama-stack-spec.html | 2 +- docs/static/deprecated-llama-stack-spec.yaml | 7 +- docs/static/llama-stack-spec.html | 5 +- docs/static/llama-stack-spec.yaml | 8 +- docs/static/stainless-llama-stack-spec.html | 5 +- docs/static/stainless-llama-stack-spec.yaml | 8 +- llama_stack/apis/inference/inference.py | 3 +- llama_stack/apis/models/models.py | 2 + llama_stack/core/routers/inference.py | 22 ++++ .../providers/utils/inference/openai_mixin.py | 41 +++--- .../utils/inference/test_openai_mixin.py | 118 ++++++++++++++++-- 12 files changed, 186 insertions(+), 43 deletions(-) diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index c2bf69962..478611420 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -3,9 +3,10 @@ description: "Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." + - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query." sidebar_label: Inference title: Inference --- @@ -18,8 +19,9 @@ Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index d920317cf..e3e182dd7 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -13467,7 +13467,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 66b2caeca..6b5b8230a 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -10218,13 +10218,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: Inference - name: Models description: '' diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 61deaec1e..584127d91 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -6859,7 +6859,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -13269,7 +13270,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index c6197b36f..90b1b3a2e 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -5269,6 +5269,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -10190,13 +10191,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: Inference - name: Inspect description: >- diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 38122ebc0..f2d99a9c7 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -8531,7 +8531,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -17959,7 +17960,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 93049a14a..9fe6cb6a3 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6482,6 +6482,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -13585,13 +13586,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: Inference - name: Inspect description: >- diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 027246470..049482837 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1234,9 +1234,10 @@ class Inference(InferenceProvider): Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query. """ @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 10949cb95..5486e3bfd 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -27,10 +27,12 @@ class ModelType(StrEnum): """Enumeration of supported model types in Llama Stack. :cvar llm: Large language model for text generation and completion :cvar embedding: Embedding model for converting text to vector representations + :cvar rerank: Reranking model for reordering documents based on their relevance to a query """ llm = "llm" embedding = "embedding" + rerank = "rerank" @json_schema_type diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index b20ad44ca..09241d836 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -44,9 +44,14 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsResponse, OpenAIMessageParam, Order, + RerankResponse, StopReason, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) from llama_stack.apis.models import Model, ModelType from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.log import get_logger @@ -182,6 +187,23 @@ class InferenceRouter(Inference): raise ModelTypeError(model_id, model.model_type, expected_model_type) return model + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + logger.debug(f"InferenceRouter.rerank: {model}") + model_obj = await self._get_model(model, ModelType.rerank) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.rerank( + model=model_obj.identifier, + query=query, + items=items, + max_num_results=max_num_results, + ) + async def openai_completion( self, params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index a9ccc8091..bbd3d2e10 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -48,6 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses - download_images: If True, downloads images and converts to base64 for providers that require it - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata + - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier - provider_data_api_key_field: Optional field name in provider data to look for API key - list_provider_model_ids: Method to list available models from the provider - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client @@ -121,6 +122,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): """ return {} + def construct_model_from_identifier(self, identifier: str) -> Model: + """ + Construct a Model instance corresponding to the given identifier + + Child classes can override this to customize model typing/metadata. + + :param identifier: The provider's model identifier + :return: A Model instance + """ + if metadata := self.embedding_model_metadata.get(identifier): + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.embedding, + metadata=metadata, + ) + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.llm, + ) + async def list_provider_model_ids(self) -> Iterable[str]: """ List available models from the provider. @@ -416,21 +441,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): if self.allowed_models and provider_model_id not in self.allowed_models: logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue - if metadata := self.embedding_model_metadata.get(provider_model_id): - model = Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=provider_model_id, - identifier=provider_model_id, - model_type=ModelType.embedding, - metadata=metadata, - ) - else: - model = Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=provider_model_id, - identifier=provider_model_id, - model_type=ModelType.llm, - ) + model = self.construct_model_from_identifier(provider_model_id) self._model_cache[provider_model_id] = model return list(self._model_cache.values()) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 61a1f8f61..d98c096aa 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): } +class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl): + """Test implementation that uses construct_model_from_identifier to add rerank models""" + + embedding_model_metadata: dict[str, dict[str, int]] = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + # Adds rerank models via construct_model_from_identifier + rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"} + + def construct_model_from_identifier(self, identifier: str) -> Model: + if identifier in self.rerank_model_ids: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.rerank, + ) + return super().construct_model_from_identifier(identifier) + + @pytest.fixture def mixin(): """Create a test instance of OpenAIMixin with mocked model_store""" @@ -62,6 +84,13 @@ def mixin_with_embeddings(): return OpenAIMixinWithEmbeddingsImpl(config=config) +@pytest.fixture +def mixin_with_custom_model_construction(): + """Create a test instance using custom construct_model_from_identifier""" + config = RemoteInferenceProviderConfig() + return OpenAIMixinWithCustomModelConstruction(config=config) + + @pytest.fixture def mock_models(): """Create multiple mock OpenAI model objects""" @@ -113,6 +142,19 @@ def mock_client_context(): return _mock_client_context +def _assert_models_match_expected(actual_models, expected_models): + """Verify the models match expected attributes. + + Args: + actual_models: List of models to verify + expected_models: Mapping of model identifier to expected attribute values + """ + for identifier, expected_attrs in expected_models.items(): + model = next(m for m in actual_models if m.identifier == identifier) + for attr_name, expected_value in expected_attrs.items(): + assert getattr(model, attr_name) == expected_value + + class TestOpenAIMixinListModels: """Test cases for the list_models method""" @@ -342,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata: assert result is not None assert len(result) == 2 - # Find the models in the result - embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") - llm_model = next(m for m in result if m.identifier == "gpt-4") + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } - # Check embedding model - assert embedding_model.model_type == ModelType.embedding - assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} - assert embedding_model.provider_id == "test-provider" - assert embedding_model.provider_resource_id == "text-embedding-3-small" + _assert_models_match_expected(result, expected_models) - # Check LLM model - assert llm_model.model_type == ModelType.llm - assert llm_model.metadata == {} # No metadata for LLMs - assert llm_model.provider_id == "test-provider" - assert llm_model.provider_resource_id == "gpt-4" + +class TestOpenAIMixinCustomModelConstruction: + """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier""" + + async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context): + """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" + # Create mock models: 1 embedding, 1 rerank, 1 LLM + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_custom_model_construction, mock_client): + result = await mixin_with_custom_model_construction.list_models() + + assert result is not None + assert len(result) == 3 + + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) class TestOpenAIMixinAllowedModels: