mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 6b4940806f
into 188a56af5c
This commit is contained in:
commit
1e04f105f2
20 changed files with 840 additions and 34 deletions
|
@ -1,9 +1,10 @@
|
|||
---
|
||||
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of models are supported:
|
||||
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search."
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
- Rerank models: these models reorder the documents based on their relevance to a query."
|
||||
sidebar_label: Inference
|
||||
title: Inference
|
||||
---
|
||||
|
@ -14,8 +15,9 @@ title: Inference
|
|||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
- Rerank models: these models reorder the documents based on their relevance to a query.
|
||||
|
||||
This section contains documentation for all available providers for the **inference** API.
|
||||
|
|
2
docs/static/deprecated-llama-stack-spec.html
vendored
2
docs/static/deprecated-llama-stack-spec.html
vendored
|
@ -13335,7 +13335,7 @@
|
|||
},
|
||||
{
|
||||
"name": "Inference",
|
||||
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
|
||||
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
|
||||
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
|
||||
},
|
||||
{
|
||||
|
|
7
docs/static/deprecated-llama-stack-spec.yaml
vendored
7
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
@ -9990,13 +9990,16 @@ tags:
|
|||
description: ''
|
||||
- name: Inference
|
||||
description: >-
|
||||
This API provides the raw interface to the underlying models. Two kinds of models
|
||||
are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of
|
||||
models are supported:
|
||||
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
|
||||
- Embedding models: these models generate embeddings to be used for semantic
|
||||
search.
|
||||
|
||||
- Rerank models: these models reorder the documents based on their relevance
|
||||
to a query.
|
||||
x-displayName: >-
|
||||
Llama Stack Inference API for generating completions, chat completions, and
|
||||
embeddings.
|
||||
|
|
|
@ -4992,7 +4992,7 @@
|
|||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the reranking model to use."
|
||||
"description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
|
||||
},
|
||||
"query": {
|
||||
"oneOf": [
|
||||
|
|
|
@ -3657,7 +3657,8 @@ components:
|
|||
model:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the reranking model to use.
|
||||
The identifier of the reranking model to use. The model must be a reranking
|
||||
model registered with Llama Stack and available via the /models endpoint.
|
||||
query:
|
||||
oneOf:
|
||||
- type: string
|
||||
|
|
5
docs/static/llama-stack-spec.html
vendored
5
docs/static/llama-stack-spec.html
vendored
|
@ -6829,7 +6829,8 @@
|
|||
"type": "string",
|
||||
"enum": [
|
||||
"llm",
|
||||
"embedding"
|
||||
"embedding",
|
||||
"rerank"
|
||||
],
|
||||
"title": "ModelType",
|
||||
"description": "Enumeration of supported model types in Llama Stack."
|
||||
|
@ -12883,7 +12884,7 @@
|
|||
},
|
||||
{
|
||||
"name": "Inference",
|
||||
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
|
||||
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
|
||||
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
|
||||
},
|
||||
{
|
||||
|
|
8
docs/static/llama-stack-spec.yaml
vendored
8
docs/static/llama-stack-spec.yaml
vendored
|
@ -5158,6 +5158,7 @@ components:
|
|||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
|
@ -9728,13 +9729,16 @@ tags:
|
|||
description: ''
|
||||
- name: Inference
|
||||
description: >-
|
||||
This API provides the raw interface to the underlying models. Two kinds of models
|
||||
are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of
|
||||
models are supported:
|
||||
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
|
||||
- Embedding models: these models generate embeddings to be used for semantic
|
||||
search.
|
||||
|
||||
- Rerank models: these models reorder the documents based on their relevance
|
||||
to a query.
|
||||
x-displayName: >-
|
||||
Llama Stack Inference API for generating completions, chat completions, and
|
||||
embeddings.
|
||||
|
|
7
docs/static/stainless-llama-stack-spec.html
vendored
7
docs/static/stainless-llama-stack-spec.html
vendored
|
@ -8838,7 +8838,8 @@
|
|||
"type": "string",
|
||||
"enum": [
|
||||
"llm",
|
||||
"embedding"
|
||||
"embedding",
|
||||
"rerank"
|
||||
],
|
||||
"title": "ModelType",
|
||||
"description": "Enumeration of supported model types in Llama Stack."
|
||||
|
@ -17033,7 +17034,7 @@
|
|||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the reranking model to use."
|
||||
"description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
|
||||
},
|
||||
"query": {
|
||||
"oneOf": [
|
||||
|
@ -18456,7 +18457,7 @@
|
|||
},
|
||||
{
|
||||
"name": "Inference",
|
||||
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
|
||||
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
|
||||
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
|
||||
},
|
||||
{
|
||||
|
|
11
docs/static/stainless-llama-stack-spec.yaml
vendored
11
docs/static/stainless-llama-stack-spec.yaml
vendored
|
@ -6603,6 +6603,7 @@ components:
|
|||
enum:
|
||||
- llm
|
||||
- embedding
|
||||
- rerank
|
||||
title: ModelType
|
||||
description: >-
|
||||
Enumeration of supported model types in Llama Stack.
|
||||
|
@ -12693,7 +12694,8 @@ components:
|
|||
model:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the reranking model to use.
|
||||
The identifier of the reranking model to use. The model must be a reranking
|
||||
model registered with Llama Stack and available via the /models endpoint.
|
||||
query:
|
||||
oneOf:
|
||||
- type: string
|
||||
|
@ -13774,13 +13776,16 @@ tags:
|
|||
description: ''
|
||||
- name: Inference
|
||||
description: >-
|
||||
This API provides the raw interface to the underlying models. Two kinds of models
|
||||
are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of
|
||||
models are supported:
|
||||
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
|
||||
- Embedding models: these models generate embeddings to be used for semantic
|
||||
search.
|
||||
|
||||
- Rerank models: these models reorder the documents based on their relevance
|
||||
to a query.
|
||||
x-displayName: >-
|
||||
Llama Stack Inference API for generating completions, chat completions, and
|
||||
embeddings.
|
||||
|
|
|
@ -1016,7 +1016,7 @@ class InferenceProvider(Protocol):
|
|||
) -> RerankResponse:
|
||||
"""Rerank a list of documents based on their relevance to a query.
|
||||
|
||||
:param model: The identifier of the reranking model to use.
|
||||
:param model: The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint.
|
||||
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
|
||||
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
|
||||
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
|
||||
|
@ -1159,9 +1159,10 @@ class InferenceProvider(Protocol):
|
|||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
This API provides the raw interface to the underlying models. Three kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
- Rerank models: these models reorder the documents based on their relevance to a query.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
|
|
|
@ -27,10 +27,12 @@ class ModelType(StrEnum):
|
|||
"""Enumeration of supported model types in Llama Stack.
|
||||
:cvar llm: Large language model for text generation and completion
|
||||
:cvar embedding: Embedding model for converting text to vector representations
|
||||
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
|
||||
"""
|
||||
|
||||
llm = "llm"
|
||||
embedding = "embedding"
|
||||
rerank = "rerank"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -41,9 +41,14 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -179,6 +184,23 @@ class InferenceRouter(Inference):
|
|||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
logger.debug(f"InferenceRouter.rerank: {model}")
|
||||
model_obj = await self._get_model(model, ModelType.rerank)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.rerank(
|
||||
model=model_obj.identifier,
|
||||
query=query,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -188,3 +188,22 @@ vlm_response = client.chat.completions.create(
|
|||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Rerank Example
|
||||
|
||||
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
rerank_response = client.inference.rerank(
|
||||
model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
query="query",
|
||||
items=[
|
||||
"item_1",
|
||||
"item_2",
|
||||
"item_3",
|
||||
],
|
||||
)
|
||||
|
||||
for i, result in enumerate(rerank_response):
|
||||
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||
```
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import aiohttp
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -12,7 +13,14 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
@ -44,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
rerank_model_list = [
|
||||
"nv-rerank-qa-mistral-4b:1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
]
|
||||
|
||||
_rerank_model_endpoints = {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
}
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
|
@ -62,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
# "Consider removing the api_key from the configuration."
|
||||
# )
|
||||
|
||||
super().__init__()
|
||||
|
||||
self._config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
|
@ -80,6 +102,103 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"""
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available NVIDIA models by combining:
|
||||
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
|
||||
2. Static rerank models (which use different API endpoints)
|
||||
"""
|
||||
self._model_cache = {}
|
||||
models = await super().list_models()
|
||||
|
||||
# Add rerank models
|
||||
existing_ids = {m.identifier for m in models}
|
||||
for model_id, _ in self._rerank_model_endpoints.items():
|
||||
if self.allowed_models and model_id not in self.allowed_models:
|
||||
continue
|
||||
if model_id not in existing_ids:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=model_id,
|
||||
identifier=model_id,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
models.append(model)
|
||||
self._model_cache[model_id] = model
|
||||
|
||||
return models
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints:
|
||||
ranking_url = self._rerank_model_endpoints[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
# Convert query to text format
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
|
||||
query_text = query.text
|
||||
else:
|
||||
raise ValueError("Query must be a string or text content part")
|
||||
|
||||
# Convert items to text format
|
||||
passages = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
passages.append({"text": item})
|
||||
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
|
||||
passages.append({"text": item.text})
|
||||
else:
|
||||
raise ValueError("Items must be strings or text content parts")
|
||||
|
||||
payload = {
|
||||
"model": provider_model_id,
|
||||
"query": {"text": query_text},
|
||||
"passages": passages,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_api_key()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(ranking_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
raise ConnectionError(
|
||||
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
rankings = result.get("rankings", [])
|
||||
|
||||
# Convert to RerankData format
|
||||
rerank_data = []
|
||||
for ranking in rankings:
|
||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||
|
||||
# Apply max_num_results limit
|
||||
if max_num_results is not None:
|
||||
rerank_data = rerank_data[:max_num_results]
|
||||
|
||||
return RerankResponse(data=rerank_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
|||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
||||
|
||||
# List of rerank model IDs for this provider
|
||||
# Can be set by subclasses or instances to provide rerank models
|
||||
rerank_model_list: list[str] = []
|
||||
|
||||
# Cache of available models keyed by model ID
|
||||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
|||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
elif m.id in self.rerank_model_list:
|
||||
# This is a rerank model
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
else:
|
||||
# This is an LLM
|
||||
model = Model(
|
||||
|
|
|
@ -120,6 +120,10 @@ def pytest_addoption(parser):
|
|||
"--embedding-model",
|
||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--rerank-model",
|
||||
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||
|
@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
|
|||
"shield_id": ("--safety-shield", "shield"),
|
||||
"judge_model_id": ("--judge-model", "judge"),
|
||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||
}
|
||||
|
||||
# Collect all parameters and their values
|
||||
|
|
|
@ -119,6 +119,7 @@ def client_with_models(
|
|||
embedding_model_id,
|
||||
embedding_dimension,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
@ -151,6 +152,20 @@ def client_with_models(
|
|||
model_type="embedding",
|
||||
metadata={"embedding_dimension": embedding_dimension or 384},
|
||||
)
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
selected_provider = None
|
||||
for p in providers:
|
||||
# Currently only NVIDIA inference provider supports reranking
|
||||
if p.provider_type == "remote::nvidia":
|
||||
selected_provider = p
|
||||
break
|
||||
|
||||
selected_provider = selected_provider or providers[0]
|
||||
client.models.register(
|
||||
model_id=rerank_model_id,
|
||||
provider_id=selected_provider.provider_id,
|
||||
model_type="rerank",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
|
@ -166,7 +181,14 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = [
|
||||
"text_model_id",
|
||||
"vision_model_id",
|
||||
"embedding_model_id",
|
||||
"judge_model_id",
|
||||
"shield_id",
|
||||
"rerank_model_id",
|
||||
]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
214
tests/integration/inference/test_rerank.py
Normal file
214
tests/integration/inference/test_rerank.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types.alpha import InferenceRerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
# Test data
|
||||
DUMMY_STRING = "string_1"
|
||||
DUMMY_STRING2 = "string_2"
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
|
||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
|
||||
supported_providers = {"remote::nvidia"}
|
||||
if inference_provider_type not in supported_providers:
|
||||
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
|
||||
|
||||
|
||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||
last_score = d.relevance_score
|
||||
|
||||
|
||||
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||
"""
|
||||
Validate that the expected most relevant item ranks first.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
expected_first_item: The expected first item in the ranking
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
if not response:
|
||||
raise AssertionError("No ranking data returned in response")
|
||||
|
||||
actual_first_index = response[0].index
|
||||
actual_first_item = items[actual_first_index]
|
||||
assert actual_first_item == expected_first_item, (
|
||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||
],
|
||||
ids=[
|
||||
"string-query-string-items",
|
||||
"text-query-text-items",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, list)
|
||||
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
],
|
||||
ids=[
|
||||
"image-query-url",
|
||||
"image-query-base64",
|
||||
"text-query-image-item",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||
error_type = (
|
||||
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
else:
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||
max_num_results = 2
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items) # Should return at most len(items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items,expected_first_item",
|
||||
[
|
||||
(
|
||||
"What is a reranking model? ",
|
||||
[
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
"Machine learning algorithms learn patterns from data. ",
|
||||
"Python is a programming language. ",
|
||||
],
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
),
|
||||
(
|
||||
"What is C++?",
|
||||
[
|
||||
"Learning new things is interesting. ",
|
||||
"C++ is a programming language. ",
|
||||
"Books provide knowledge and entertainment. ",
|
||||
],
|
||||
"C++ is a programming language. ",
|
||||
),
|
||||
(
|
||||
"What are good learning habits? ",
|
||||
[
|
||||
"Cooking pasta is a fun activity. ",
|
||||
"Plants need water and sunlight. ",
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
],
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rerank_semantic_correctness(
|
||||
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
||||
):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
_validate_rerank_response(response, items)
|
||||
_validate_semantic_ranking(response, items, expected_first_item)
|
222
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
222
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"rankings": []}
|
||||
self._text_data = text_data
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.post_calls = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
|
||||
class PostContext:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.response
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
adapter = NVIDIAInferenceAdapter(config)
|
||||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter._rerank_model_endpoints = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
async def test_rerank_basic_functionality():
|
||||
adapter = create_adapter()
|
||||
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.5
|
||||
|
||||
url, kwargs = mock_session.post_calls[0]
|
||||
payload = kwargs["json"]
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["query"] == {"text": "test query"}
|
||||
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||
|
||||
|
||||
async def test_missing_rankings_key():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(json_data={}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert url == "https://model.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "http://localhost:8000" in url
|
||||
assert "model.endpoint/rerank" not in url
|
||||
|
||||
|
||||
async def test_max_num_results():
|
||||
adapter = create_adapter()
|
||||
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.8
|
||||
|
||||
|
||||
async def test_http_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_client_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_adds_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
|
||||
# Mock the list_models from the superclass to return some dynamic models
|
||||
base_models = [
|
||||
MagicMock(identifier="llm-1", model_type=ModelType.llm),
|
||||
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
|
||||
]
|
||||
|
||||
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for rerank_model in rerank_models:
|
||||
assert rerank_model.provider_id == "nvidia"
|
||||
assert rerank_model.metadata == {}
|
||||
assert rerank_model.identifier in adapter._model_cache
|
|
@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
|||
}
|
||||
|
||||
|
||||
class OpenAIMixinWithRerankImpl(OpenAIMixin):
|
||||
"""Test implementation with rerank model list"""
|
||||
|
||||
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
|
||||
|
||||
def __init__(self):
|
||||
self.__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin):
|
||||
"""Test implementation with both embedding model metadata and rerank model list"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
|
||||
|
||||
__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
|
@ -56,6 +90,18 @@ def mixin_with_embeddings():
|
|||
return OpenAIMixinWithEmbeddingsImpl()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_rerank():
|
||||
"""Create a test instance of OpenAIMixin with rerank model list"""
|
||||
return OpenAIMixinWithRerankImpl()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_embeddings_and_rerank():
|
||||
"""Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list"""
|
||||
return OpenAIMixinWithEmbeddingsAndRerankImpl()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models():
|
||||
"""Create multiple mock OpenAI model objects"""
|
||||
|
@ -107,6 +153,19 @@ def mock_client_context():
|
|||
return _mock_client_context
|
||||
|
||||
|
||||
def _assert_models_match_expected(actual_models, expected_models):
|
||||
"""Verify the models match expected attributes.
|
||||
|
||||
Args:
|
||||
actual_models: List of models to verify
|
||||
expected_models: Mapping of model identifier to expected attribute values
|
||||
"""
|
||||
for identifier, expected_attrs in expected_models.items():
|
||||
model = next(m for m in actual_models if m.identifier == identifier)
|
||||
for attr_name, expected_value in expected_attrs.items():
|
||||
assert getattr(model, attr_name) == expected_value
|
||||
|
||||
|
||||
class TestOpenAIMixinListModels:
|
||||
"""Test cases for the list_models method"""
|
||||
|
||||
|
@ -300,21 +359,113 @@ class TestOpenAIMixinEmbeddingModelMetadata:
|
|||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the models in the result
|
||||
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
# Check embedding model
|
||||
assert embedding_model.model_type == ModelType.embedding
|
||||
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||
assert embedding_model.provider_id == "test-provider"
|
||||
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
# Check LLM model
|
||||
assert llm_model.model_type == ModelType.llm
|
||||
assert llm_model.metadata == {} # No metadata for LLMs
|
||||
assert llm_model.provider_id == "test-provider"
|
||||
assert llm_model.provider_resource_id == "gpt-4"
|
||||
|
||||
class TestOpenAIMixinRerankModelList:
|
||||
"""Test cases for rerank_model_list attribute functionality"""
|
||||
|
||||
async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context):
|
||||
"""Test that models in rerank_model_list are correctly identified as rerank models"""
|
||||
# Create mock models: 1 rerank model and 1 LLM
|
||||
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||
mock_llm_model = MagicMock(id="gpt-4")
|
||||
mock_models = [mock_rerank_model, mock_llm_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
with mock_client_context(mixin_with_rerank, mock_client):
|
||||
result = await mixin_with_rerank.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
expected_models = {
|
||||
"rerank-model-1": {
|
||||
"model_type": ModelType.rerank,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "rerank-model-1",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
|
||||
class TestOpenAIMixinMixedModelTypes:
|
||||
"""Test cases for mixed model types (LLM, embedding, rerank)"""
|
||||
|
||||
async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context):
|
||||
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
|
||||
# Create mock models: 1 embedding, 1 rerank, 1 LLM
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||
mock_llm_model = MagicMock(id="gpt-4")
|
||||
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
with mock_client_context(mixin_with_embeddings_and_rerank, mock_client):
|
||||
result = await mixin_with_embeddings_and_rerank.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"rerank-model-1": {
|
||||
"model_type": ModelType.rerank,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "rerank-model-1",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModels:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue