Add rerank models and rerank API change

This commit is contained in:
Jiayi 2025-10-16 17:27:38 -07:00
parent f675fdda0f
commit 51c923f096
12 changed files with 215 additions and 28 deletions

View file

@ -3,9 +3,10 @@ description: "Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings. Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported: This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search." - Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models (Experimental): these models reorder the documents based on their relevance to a query."
sidebar_label: Inference sidebar_label: Inference
title: Inference title: Inference
--- ---
@ -18,8 +19,9 @@ Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings. Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported: This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search. - Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.
This section contains documentation for all available providers for the **inference** API. This section contains documentation for all available providers for the **inference** API.

View file

@ -13459,7 +13459,7 @@
}, },
{ {
"name": "Inference", "name": "Inference",
"description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.",
"x-displayName": "Inference" "x-displayName": "Inference"
}, },
{ {

View file

@ -10210,13 +10210,16 @@ tags:
embeddings. embeddings.
This API provides the raw interface to the underlying models. Two kinds of models This API provides the raw interface to the underlying models. Three kinds of
are supported: models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic - Embedding models: these models generate embeddings to be used for semantic
search. search.
- Rerank models (Experimental): these models reorder the documents based on
their relevance to a query.
x-displayName: Inference x-displayName: Inference
- name: Models - name: Models
description: '' description: ''

View file

@ -6859,7 +6859,8 @@
"type": "string", "type": "string",
"enum": [ "enum": [
"llm", "llm",
"embedding" "embedding",
"rerank"
], ],
"title": "ModelType", "title": "ModelType",
"description": "Enumeration of supported model types in Llama Stack." "description": "Enumeration of supported model types in Llama Stack."
@ -13261,7 +13262,7 @@
}, },
{ {
"name": "Inference", "name": "Inference",
"description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.",
"x-displayName": "Inference" "x-displayName": "Inference"
}, },
{ {

View file

@ -5269,6 +5269,7 @@ components:
enum: enum:
- llm - llm
- embedding - embedding
- rerank
title: ModelType title: ModelType
description: >- description: >-
Enumeration of supported model types in Llama Stack. Enumeration of supported model types in Llama Stack.
@ -10182,13 +10183,16 @@ tags:
embeddings. embeddings.
This API provides the raw interface to the underlying models. Two kinds of models This API provides the raw interface to the underlying models. Three kinds of
are supported: models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic - Embedding models: these models generate embeddings to be used for semantic
search. search.
- Rerank models (Experimental): these models reorder the documents based on
their relevance to a query.
x-displayName: Inference x-displayName: Inference
- name: Inspect - name: Inspect
description: >- description: >-

View file

@ -8531,7 +8531,8 @@
"type": "string", "type": "string",
"enum": [ "enum": [
"llm", "llm",
"embedding" "embedding",
"rerank"
], ],
"title": "ModelType", "title": "ModelType",
"description": "Enumeration of supported model types in Llama Stack." "description": "Enumeration of supported model types in Llama Stack."
@ -17951,7 +17952,7 @@
}, },
{ {
"name": "Inference", "name": "Inference",
"description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.",
"x-displayName": "Inference" "x-displayName": "Inference"
}, },
{ {

View file

@ -6482,6 +6482,7 @@ components:
enum: enum:
- llm - llm
- embedding - embedding
- rerank
title: ModelType title: ModelType
description: >- description: >-
Enumeration of supported model types in Llama Stack. Enumeration of supported model types in Llama Stack.
@ -13577,13 +13578,16 @@ tags:
embeddings. embeddings.
This API provides the raw interface to the underlying models. Two kinds of models This API provides the raw interface to the underlying models. Three kinds of
are supported: models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic - Embedding models: these models generate embeddings to be used for semantic
search. search.
- Rerank models (Experimental): these models reorder the documents based on
their relevance to a query.
x-displayName: Inference x-displayName: Inference
- name: Inspect - name: Inspect
description: >- description: >-

View file

@ -1234,9 +1234,10 @@ class Inference(InferenceProvider):
Llama Stack Inference API for generating completions, chat completions, and embeddings. Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported: This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search. - Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.
""" """
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)

View file

@ -27,10 +27,12 @@ class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack. """Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion :cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations :cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
""" """
llm = "llm" llm = "llm"
embedding = "embedding" embedding = "embedding"
rerank = "rerank"
@json_schema_type @json_schema_type

View file

@ -44,9 +44,14 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIMessageParam, OpenAIMessageParam,
Order, Order,
RerankResponse,
StopReason, StopReason,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -182,6 +187,23 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type) raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model return model
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
logger.debug(f"InferenceRouter.rerank: {model}")
model_obj = await self._get_model(model, ModelType.rerank)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.rerank(
model=model_obj.identifier,
query=query,
items=items,
max_num_results=max_num_results,
)
async def openai_completion( async def openai_completion(
self, self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],

View file

@ -78,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
embedding_model_metadata: dict[str, dict[str, int]] = {} embedding_model_metadata: dict[str, dict[str, int]] = {}
# List of rerank model IDs for this provider
# Can be set by subclasses or instances to provide rerank models
rerank_model_list: list[str] = []
# Cache of available models keyed by model ID # Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability() # This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {} _model_cache: dict[str, Model] = {}
@ -424,6 +428,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata=metadata, metadata=metadata,
) )
elif provider_model_id in self.rerank_model_list:
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.rerank,
)
else: else:
model = Model( model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined] provider_id=self.__provider_id__, # type: ignore[attr-defined]

View file

@ -38,6 +38,23 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
} }
class OpenAIMixinWithRerankImpl(OpenAIMixinImpl):
"""Test implementation with rerank model list"""
rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"]
class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixinImpl):
"""Test implementation with both embedding model metadata and rerank model list"""
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"]
@pytest.fixture @pytest.fixture
def mixin(): def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store""" """Create a test instance of OpenAIMixin with mocked model_store"""
@ -62,6 +79,20 @@ def mixin_with_embeddings():
return OpenAIMixinWithEmbeddingsImpl(config=config) return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
def mixin_with_rerank():
"""Create a test instance of OpenAIMixin with rerank model list"""
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithRerankImpl(config=config)
@pytest.fixture
def mixin_with_embeddings_and_rerank():
"""Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list"""
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsAndRerankImpl(config=config)
@pytest.fixture @pytest.fixture
def mock_models(): def mock_models():
"""Create multiple mock OpenAI model objects""" """Create multiple mock OpenAI model objects"""
@ -113,6 +144,19 @@ def mock_client_context():
return _mock_client_context return _mock_client_context
def _assert_models_match_expected(actual_models, expected_models):
"""Verify the models match expected attributes.
Args:
actual_models: List of models to verify
expected_models: Mapping of model identifier to expected attribute values
"""
for identifier, expected_attrs in expected_models.items():
model = next(m for m in actual_models if m.identifier == identifier)
for attr_name, expected_value in expected_attrs.items():
assert getattr(model, attr_name) == expected_value
class TestOpenAIMixinListModels: class TestOpenAIMixinListModels:
"""Test cases for the list_models method""" """Test cases for the list_models method"""
@ -342,21 +386,113 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert result is not None assert result is not None
assert len(result) == 2 assert len(result) == 2
# Find the models in the result expected_models = {
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") "text-embedding-3-small": {
llm_model = next(m for m in result if m.identifier == "gpt-4") "model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
# Check embedding model _assert_models_match_expected(result, expected_models)
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
# Check LLM model
assert llm_model.model_type == ModelType.llm class TestOpenAIMixinRerankModelList:
assert llm_model.metadata == {} # No metadata for LLMs """Test cases for rerank_model_list attribute functionality"""
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4" async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context):
"""Test that models in rerank_model_list are correctly identified as rerank models"""
# Create mock models: 1 rerank model and 1 LLM
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_rerank, mock_client):
result = await mixin_with_rerank.list_models()
assert result is not None
assert len(result) == 2
expected_models = {
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinMixedModelTypes:
"""Test cases for mixed model types (LLM, embedding, rerank)"""
async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context):
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
# Create mock models: 1 embedding, 1 rerank, 1 LLM
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_embeddings_and_rerank, mock_client):
result = await mixin_with_embeddings_and_rerank.list_models()
assert result is not None
assert len(result) == 3
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinAllowedModels: class TestOpenAIMixinAllowedModels: