feat: Add rerank models and rerank API change (#3831)

# What does this PR do?
- Extend the model type enum to include rerank models.
- Implement a `rerank()` method in the inference router (a usage sketch follows this list).
- Add a `construct_model_from_identifier()` hook to `OpenAIMixin` so providers can register and identify rerank models.
- Update documentation.
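For orientation, here is a minimal end-to-end sketch of what the change enables. This is hedged: the client namespace (`client.alpha.inference.rerank`), base URL, and model ID are illustrative assumptions; only the router-side `rerank()` signature shown in the diff below is from this PR.

```python
# Hypothetical client-side usage of the new rerank API.
# The namespace and IDs are assumptions, not confirmed by this diff.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.alpha.inference.rerank(  # assumed client namespace
    model="rerank-model-1",  # a model registered with model_type="rerank"
    query="What is the capital of France?",
    items=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "The Eiffel Tower is in Paris.",
    ],
    max_num_results=2,
)

# The response data carries (index, relevance_score) pairs, ordered by relevance.
for item in response.data:
    print(item.index, item.relevance_score)
```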


## Test Plan
```
pytest tests/unit/providers/utils/inference/test_openai_mixin.py
```
Jiayi Ni · 2025-10-22 12:02:28 -07:00 · committed by GitHub
parent f2598d30e6 · commit bb1ebb3c6b
12 changed files with 186 additions and 43 deletions


```diff
@@ -3,9 +3,10 @@ description: "Inference
 Llama Stack Inference API for generating completions, chat completions, and embeddings.
 
-This API provides the raw interface to the underlying models. Two kinds of models are supported:
+This API provides the raw interface to the underlying models. Three kinds of models are supported:
 - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
-- Embedding models: these models generate embeddings to be used for semantic search."
+- Embedding models: these models generate embeddings to be used for semantic search.
+- Rerank models: these models reorder the documents based on their relevance to a query."
 sidebar_label: Inference
 title: Inference
 ---
@@ -18,8 +19,9 @@ Inference
 Llama Stack Inference API for generating completions, chat completions, and embeddings.
 
-This API provides the raw interface to the underlying models. Two kinds of models are supported:
+This API provides the raw interface to the underlying models. Three kinds of models are supported:
 - LLM models: these models generate "raw" and "chat" (conversational) completions.
 - Embedding models: these models generate embeddings to be used for semantic search.
+- Rerank models: these models reorder the documents based on their relevance to a query.
 
 This section contains documentation for all available providers for the **inference** API.
```


```diff
@@ -13467,7 +13467,7 @@
         },
         {
             "name": "Inference",
-            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
+            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
             "x-displayName": "Inference"
         },
         {
```


```diff
@@ -10218,13 +10218,16 @@ tags:
     embeddings.
 
-    This API provides the raw interface to the underlying models. Two kinds of models
-    are supported:
+    This API provides the raw interface to the underlying models. Three kinds of
+    models are supported:
 
     - LLM models: these models generate "raw" and "chat" (conversational) completions.
 
     - Embedding models: these models generate embeddings to be used for semantic
     search.
+
+    - Rerank models: these models reorder the documents based on their relevance
+    to a query.
   x-displayName: Inference
 - name: Models
   description: ''
```


```diff
@@ -6859,7 +6859,8 @@
             "type": "string",
             "enum": [
                 "llm",
-                "embedding"
+                "embedding",
+                "rerank"
             ],
             "title": "ModelType",
             "description": "Enumeration of supported model types in Llama Stack."
@@ -13269,7 +13270,7 @@
         },
         {
             "name": "Inference",
-            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
+            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
             "x-displayName": "Inference"
         },
         {
```


```diff
@@ -5269,6 +5269,7 @@ components:
       enum:
         - llm
         - embedding
+        - rerank
       title: ModelType
       description: >-
         Enumeration of supported model types in Llama Stack.
@@ -10190,13 +10191,16 @@ tags:
     embeddings.
 
-    This API provides the raw interface to the underlying models. Two kinds of models
-    are supported:
+    This API provides the raw interface to the underlying models. Three kinds of
+    models are supported:
 
     - LLM models: these models generate "raw" and "chat" (conversational) completions.
 
     - Embedding models: these models generate embeddings to be used for semantic
     search.
+
+    - Rerank models: these models reorder the documents based on their relevance
+    to a query.
   x-displayName: Inference
 - name: Inspect
   description: >-
```


```diff
@@ -8531,7 +8531,8 @@
             "type": "string",
             "enum": [
                 "llm",
-                "embedding"
+                "embedding",
+                "rerank"
             ],
             "title": "ModelType",
             "description": "Enumeration of supported model types in Llama Stack."
@@ -17959,7 +17960,7 @@
         },
         {
             "name": "Inference",
-            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
+            "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
             "x-displayName": "Inference"
         },
         {
```


```diff
@@ -6482,6 +6482,7 @@ components:
      enum:
        - llm
        - embedding
+        - rerank
      title: ModelType
      description: >-
        Enumeration of supported model types in Llama Stack.
@@ -13585,13 +13586,16 @@ tags:
     embeddings.
 
-    This API provides the raw interface to the underlying models. Two kinds of models
-    are supported:
+    This API provides the raw interface to the underlying models. Three kinds of
+    models are supported:
 
     - LLM models: these models generate "raw" and "chat" (conversational) completions.
 
     - Embedding models: these models generate embeddings to be used for semantic
     search.
+
+    - Rerank models: these models reorder the documents based on their relevance
+    to a query.
   x-displayName: Inference
 - name: Inspect
   description: >-
```


```diff
@@ -1234,9 +1234,10 @@ class Inference(InferenceProvider):
     Llama Stack Inference API for generating completions, chat completions, and embeddings.
 
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    This API provides the raw interface to the underlying models. Three kinds of models are supported:
     - LLM models: these models generate "raw" and "chat" (conversational) completions.
     - Embedding models: these models generate embeddings to be used for semantic search.
+    - Rerank models: these models reorder the documents based on their relevance to a query.
     """
 
     @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
```


```diff
@@ -27,10 +27,12 @@ class ModelType(StrEnum):
     """Enumeration of supported model types in Llama Stack.
 
     :cvar llm: Large language model for text generation and completion
     :cvar embedding: Embedding model for converting text to vector representations
+    :cvar rerank: Reranking model for reordering documents based on their relevance to a query
     """
 
     llm = "llm"
     embedding = "embedding"
+    rerank = "rerank"
 
 
 @json_schema_type
```
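With the enum extended, a rerank model can be registered through the existing Models API like any other model. A minimal hedged sketch; the model and provider IDs are placeholders, and the keyword arguments assume the usual `models.register` client surface:

```python
# Hedged sketch: registering a model under the new "rerank" type.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.models.register(
    model_id="rerank-model-1",       # placeholder identifier
    provider_id="example-provider",  # placeholder provider
    model_type="rerank",             # serialized form of the new ModelType.rerank
)
```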


```diff
@@ -44,9 +44,14 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     Order,
+    RerankResponse,
     StopReason,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.log import get_logger
@@ -182,6 +187,23 @@ class InferenceRouter(Inference):
             raise ModelTypeError(model_id, model.model_type, expected_model_type)
         return model
 
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        logger.debug(f"InferenceRouter.rerank: {model}")
+        model_obj = await self._get_model(model, ModelType.rerank)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
+        return await provider.rerank(
+            model=model_obj.identifier,
+            query=query,
+            items=items,
+            max_num_results=max_num_results,
+        )
+
     async def openai_completion(
         self,
         params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
```
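The router resolves the model, verifies it is `ModelType.rerank` via `_get_model`, and forwards the call to the owning provider, so a provider needs a matching `rerank()` implementation. A hedged sketch of that contract, assuming `RerankResponse` wraps a list of `RerankData(index, relevance_score)` entries; the scoring logic is a toy placeholder, not a real provider:

```python
# Hypothetical provider-side implementation of the rerank contract.
from llama_stack.apis.inference import RerankData, RerankResponse


class ExampleRerankProvider:
    async def rerank(
        self,
        model: str,
        query,
        items,
        max_num_results: int | None = None,
    ) -> RerankResponse:
        # Score every item against the query (placeholder lexical overlap;
        # a real provider would call its rerank model here).
        scored = [(i, self._score(query, item)) for i, item in enumerate(items)]
        # Reorder by descending relevance and truncate if requested.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        if max_num_results is not None:
            scored = scored[:max_num_results]
        return RerankResponse(
            data=[RerankData(index=i, relevance_score=score) for i, score in scored]
        )

    @staticmethod
    def _score(query, item) -> float:
        # Toy token-overlap score, purely illustrative.
        q = set(str(query).lower().split())
        d = set(str(item).lower().split())
        return len(q & d) / (len(q) or 1)
```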


```diff
@@ -48,6 +48,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
     - download_images: If True, downloads images and converts to base64 for providers that require it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
+    - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
     - provider_data_api_key_field: Optional field name in provider data to look for API key
     - list_provider_model_ids: Method to list available models from the provider
     - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
@@ -121,6 +122,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         return {}
 
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Construct a Model instance corresponding to the given identifier.
+
+        Child classes can override this to customize model typing/metadata.
+
+        :param identifier: The provider's model identifier
+        :return: A Model instance
+        """
+        if metadata := self.embedding_model_metadata.get(identifier):
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.embedding,
+                metadata=metadata,
+            )
+        return Model(
+            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+            provider_resource_id=identifier,
+            identifier=identifier,
+            model_type=ModelType.llm,
+        )
+
     async def list_provider_model_ids(self) -> Iterable[str]:
         """
         List available models from the provider.
@@ -416,21 +441,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             if self.allowed_models and provider_model_id not in self.allowed_models:
                 logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
-            if metadata := self.embedding_model_metadata.get(provider_model_id):
-                model = Model(
-                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=provider_model_id,
-                    identifier=provider_model_id,
-                    model_type=ModelType.embedding,
-                    metadata=metadata,
-                )
-            else:
-                model = Model(
-                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=provider_model_id,
-                    identifier=provider_model_id,
-                    model_type=ModelType.llm,
-                )
+            model = self.construct_model_from_identifier(provider_model_id)
             self._model_cache[provider_model_id] = model
 
         return list(self._model_cache.values())
```


```diff
@@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
     }
 
 
+class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
+    """Test implementation that uses construct_model_from_identifier to add rerank models"""
+
+    embedding_model_metadata: dict[str, dict[str, int]] = {
+        "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
+        "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
+    }
+
+    # Adds rerank models via construct_model_from_identifier
+    rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
+
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        if identifier in self.rerank_model_ids:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.rerank,
+            )
+        return super().construct_model_from_identifier(identifier)
+
+
 @pytest.fixture
 def mixin():
     """Create a test instance of OpenAIMixin with mocked model_store"""
@@ -62,6 +84,13 @@ def mixin_with_embeddings():
     return OpenAIMixinWithEmbeddingsImpl(config=config)
 
 
+@pytest.fixture
+def mixin_with_custom_model_construction():
+    """Create a test instance using custom construct_model_from_identifier"""
+    config = RemoteInferenceProviderConfig()
+    return OpenAIMixinWithCustomModelConstruction(config=config)
+
+
 @pytest.fixture
 def mock_models():
     """Create multiple mock OpenAI model objects"""
@@ -113,6 +142,19 @@ def mock_client_context():
     return _mock_client_context
 
 
+def _assert_models_match_expected(actual_models, expected_models):
+    """Verify the models match expected attributes.
+
+    Args:
+        actual_models: List of models to verify
+        expected_models: Mapping of model identifier to expected attribute values
+    """
+    for identifier, expected_attrs in expected_models.items():
+        model = next(m for m in actual_models if m.identifier == identifier)
+        for attr_name, expected_value in expected_attrs.items():
+            assert getattr(model, attr_name) == expected_value
+
+
 class TestOpenAIMixinListModels:
     """Test cases for the list_models method"""
@@ -342,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
             assert result is not None
             assert len(result) == 2
 
-            # Find the models in the result
-            embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
-            llm_model = next(m for m in result if m.identifier == "gpt-4")
-
-            # Check embedding model
-            assert embedding_model.model_type == ModelType.embedding
-            assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
-            assert embedding_model.provider_id == "test-provider"
-            assert embedding_model.provider_resource_id == "text-embedding-3-small"
-
-            # Check LLM model
-            assert llm_model.model_type == ModelType.llm
-            assert llm_model.metadata == {}  # No metadata for LLMs
-            assert llm_model.provider_id == "test-provider"
-            assert llm_model.provider_resource_id == "gpt-4"
+            expected_models = {
+                "text-embedding-3-small": {
+                    "model_type": ModelType.embedding,
+                    "metadata": {"embedding_dimension": 1536, "context_length": 8192},
+                    "provider_id": "test-provider",
+                    "provider_resource_id": "text-embedding-3-small",
+                },
+                "gpt-4": {
+                    "model_type": ModelType.llm,
+                    "metadata": {},
+                    "provider_id": "test-provider",
+                    "provider_resource_id": "gpt-4",
+                },
+            }
+
+            _assert_models_match_expected(result, expected_models)
+
+
+class TestOpenAIMixinCustomModelConstruction:
+    """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
+
+    async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
+        """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
+        # Create mock models: 1 embedding, 1 rerank, 1 LLM
+        mock_embedding_model = MagicMock(id="text-embedding-3-small")
+        mock_rerank_model = MagicMock(id="rerank-model-1")
+        mock_llm_model = MagicMock(id="gpt-4")
+        mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
+
+        mock_client = MagicMock()
+
+        async def mock_models_list():
+            for model in mock_models:
+                yield model
+
+        mock_client.models.list.return_value = mock_models_list()
+
+        with mock_client_context(mixin_with_custom_model_construction, mock_client):
+            result = await mixin_with_custom_model_construction.list_models()
+
+        assert result is not None
+        assert len(result) == 3
+
+        expected_models = {
+            "text-embedding-3-small": {
+                "model_type": ModelType.embedding,
+                "metadata": {"embedding_dimension": 1536, "context_length": 8192},
+                "provider_id": "test-provider",
+                "provider_resource_id": "text-embedding-3-small",
+            },
+            "rerank-model-1": {
+                "model_type": ModelType.rerank,
+                "metadata": {},
+                "provider_id": "test-provider",
+                "provider_resource_id": "rerank-model-1",
+            },
+            "gpt-4": {
+                "model_type": ModelType.llm,
+                "metadata": {},
+                "provider_id": "test-provider",
+                "provider_resource_id": "gpt-4",
+            },
+        }
+
+        _assert_models_match_expected(result, expected_models)
+
+
 class TestOpenAIMixinAllowedModels:
```