mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Add rerank models to the dynamic model list; Fix integration tests
This commit is contained in:
parent
3538477070
commit
816b68fdc7
8 changed files with 247 additions and 25 deletions
|
@ -5,6 +5,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
|
||||||
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
|
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
|
||||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||||
- Rerank models: these models rerank the documents by relevance."
|
- Rerank models: these models rerank the documents by relevance."
|
||||||
|
|
||||||
sidebar_label: Inference
|
sidebar_label: Inference
|
||||||
title: Inference
|
title: Inference
|
||||||
---
|
---
|
||||||
|
|
|
@ -204,6 +204,6 @@ rerank_response = client.inference.rerank(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, result in enumerate(rerank_response.data):
|
for i, result in enumerate(rerank_response):
|
||||||
print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
|
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||||
```
|
```
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletionContentPartImageParam,
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
OpenAIChatCompletionContentPartTextParam,
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
@ -51,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rerank_model_list = [
|
||||||
|
"nv-rerank-qa-mistral-4b:1",
|
||||||
|
"nvidia/nv-rerankqa-mistral-4b-v3",
|
||||||
|
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||||
|
]
|
||||||
|
|
||||||
|
_rerank_model_endpoints = {
|
||||||
|
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||||
|
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||||
|
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: NVIDIAConfig) -> None:
|
def __init__(self, config: NVIDIAConfig) -> None:
|
||||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||||
|
|
||||||
|
@ -69,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||||
# "Consider removing the api_key from the configuration."
|
# "Consider removing the api_key from the configuration."
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
|
@ -87,6 +102,30 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||||
"""
|
"""
|
||||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||||
|
|
||||||
|
async def list_models(self) -> list[Model] | None:
|
||||||
|
"""
|
||||||
|
List available NVIDIA models by combining:
|
||||||
|
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
|
||||||
|
2. Static rerank models (which use different API endpoints)
|
||||||
|
"""
|
||||||
|
models = await super().list_models() or []
|
||||||
|
|
||||||
|
existing_ids = {m.identifier for m in models}
|
||||||
|
for model_id, _ in self._rerank_model_endpoints.items():
|
||||||
|
if self.allowed_models and model_id not in self.allowed_models:
|
||||||
|
continue
|
||||||
|
if model_id not in existing_ids:
|
||||||
|
model = Model(
|
||||||
|
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||||
|
provider_resource_id=model_id,
|
||||||
|
identifier=model_id,
|
||||||
|
model_type=ModelType.rerank,
|
||||||
|
)
|
||||||
|
models.append(model)
|
||||||
|
self._model_cache[model_id] = model
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
async def rerank(
|
async def rerank(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||||
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
embedding_model_metadata: dict[str, dict[str, int]] = {}
|
||||||
|
|
||||||
|
# List of rerank model IDs for this provider
|
||||||
|
# Can be set by subclasses or instances to provide rerank models
|
||||||
|
rerank_model_list: list[str] = []
|
||||||
|
|
||||||
# Cache of available models keyed by model ID
|
# Cache of available models keyed by model ID
|
||||||
# This is set in list_models() and used in check_model_availability()
|
# This is set in list_models() and used in check_model_availability()
|
||||||
_model_cache: dict[str, Model] = {}
|
_model_cache: dict[str, Model] = {}
|
||||||
|
@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
elif m.id in self.rerank_model_list:
|
||||||
|
# This is a rerank model
|
||||||
|
model = Model(
|
||||||
|
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||||
|
provider_resource_id=m.id,
|
||||||
|
identifier=m.id,
|
||||||
|
model_type=ModelType.rerank,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# This is an LLM
|
# This is an LLM
|
||||||
model = Model(
|
model = Model(
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||||
from llama_stack_client.types import RerankResponse
|
from llama_stack_client.types import InferenceRerankResponse
|
||||||
from llama_stack_client.types.shared.interleaved_content import (
|
from llama_stack_client.types.shared.interleaved_content import (
|
||||||
ImageContentItem,
|
ImageContentItem,
|
||||||
ImageContentItemImage,
|
ImageContentItemImage,
|
||||||
|
@ -30,12 +30,12 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||||
|
|
||||||
|
|
||||||
def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||||
"""
|
"""
|
||||||
Validate that a rerank response has the correct structure and ordering.
|
Validate that a rerank response has the correct structure and ordering.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The RerankResponse to validate
|
response: The InferenceRerankResponse to validate
|
||||||
items: The original items list that was ranked
|
items: The original items list that was ranked
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -43,7 +43,7 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
||||||
"""
|
"""
|
||||||
seen = set()
|
seen = set()
|
||||||
last_score = float("inf")
|
last_score = float("inf")
|
||||||
for d in response.data:
|
for d in response:
|
||||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||||
seen.add(d.index)
|
seen.add(d.index)
|
||||||
|
@ -52,22 +52,22 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
||||||
last_score = d.relevance_score
|
last_score = d.relevance_score
|
||||||
|
|
||||||
|
|
||||||
def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
|
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||||
"""
|
"""
|
||||||
Validate that the expected most relevant item ranks first.
|
Validate that the expected most relevant item ranks first.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The RerankResponse to validate
|
response: The InferenceRerankResponse to validate
|
||||||
items: The original items list that was ranked
|
items: The original items list that was ranked
|
||||||
expected_first_item: The expected first item in the ranking
|
expected_first_item: The expected first item in the ranking
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If any validation fails
|
AssertionError: If any validation fails
|
||||||
"""
|
"""
|
||||||
if not response.data:
|
if not response:
|
||||||
raise AssertionError("No ranking data returned in response")
|
raise AssertionError("No ranking data returned in response")
|
||||||
|
|
||||||
actual_first_index = response.data[0].index
|
actual_first_index = response[0].index
|
||||||
actual_first_item = items[actual_first_index]
|
actual_first_item = items[actual_first_index]
|
||||||
assert actual_first_item == expected_first_item, (
|
assert actual_first_item == expected_first_item, (
|
||||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||||
|
@ -94,8 +94,9 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||||
|
|
||||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
assert isinstance(response, RerankResponse)
|
assert isinstance(response, list)
|
||||||
assert len(response.data) <= len(items)
|
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||||
|
assert len(response) <= len(items)
|
||||||
_validate_rerank_response(response, items)
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,8 +130,8 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
|
||||||
else:
|
else:
|
||||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
|
|
||||||
assert isinstance(response, RerankResponse)
|
assert isinstance(response, list)
|
||||||
assert len(response.data) <= len(items)
|
assert len(response) <= len(items)
|
||||||
_validate_rerank_response(response, items)
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,8 +149,8 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
|
||||||
max_num_results=max_num_results,
|
max_num_results=max_num_results,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, RerankResponse)
|
assert isinstance(response, list)
|
||||||
assert len(response.data) == max_num_results
|
assert len(response) == max_num_results
|
||||||
_validate_rerank_response(response, items)
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
@ -165,8 +166,8 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
|
||||||
max_num_results=10, # Larger than items length
|
max_num_results=10, # Larger than items length
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(response, RerankResponse)
|
assert isinstance(response, list)
|
||||||
assert len(response.data) <= len(items) # Should return at most len(items)
|
assert len(response) <= len(items) # Should return at most len(items)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
@ -4,11 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||||
|
|
||||||
|
@ -170,3 +171,35 @@ async def test_client_error():
|
||||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_models_adds_rerank_models():
|
||||||
|
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||||
|
adapter = create_adapter()
|
||||||
|
adapter.__provider_id__ = "nvidia"
|
||||||
|
|
||||||
|
# Mock the list_models from the superclass to return some dynamic models
|
||||||
|
base_models = [
|
||||||
|
MagicMock(identifier="llm-1", model_type=ModelType.llm),
|
||||||
|
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
|
||||||
|
result = await adapter.list_models()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
# Check that the rerank models are added
|
||||||
|
model_ids = [m.identifier for m in result]
|
||||||
|
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||||
|
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||||
|
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||||
|
|
||||||
|
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||||
|
|
||||||
|
assert len(rerank_models) == 3
|
||||||
|
|
||||||
|
for rerank_model in rerank_models:
|
||||||
|
assert rerank_model.provider_id == "nvidia"
|
||||||
|
assert rerank_model.metadata == {}
|
||||||
|
assert rerank_model.identifier in adapter._model_cache
|
||||||
|
|
|
@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIMixinWithRerankImpl(OpenAIMixin):
|
||||||
|
"""Test implementation with rerank model list"""
|
||||||
|
|
||||||
|
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.__provider_id__ = "test-provider"
|
||||||
|
|
||||||
|
def get_api_key(self) -> str:
|
||||||
|
raise NotImplementedError("This method should be mocked in tests")
|
||||||
|
|
||||||
|
def get_base_url(self) -> str:
|
||||||
|
raise NotImplementedError("This method should be mocked in tests")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin):
|
||||||
|
"""Test implementation with both embedding model metadata and rerank model list"""
|
||||||
|
|
||||||
|
embedding_model_metadata = {
|
||||||
|
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||||
|
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||||
|
}
|
||||||
|
|
||||||
|
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
|
||||||
|
|
||||||
|
__provider_id__ = "test-provider"
|
||||||
|
|
||||||
|
def get_api_key(self) -> str:
|
||||||
|
raise NotImplementedError("This method should be mocked in tests")
|
||||||
|
|
||||||
|
def get_base_url(self) -> str:
|
||||||
|
raise NotImplementedError("This method should be mocked in tests")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mixin():
|
def mixin():
|
||||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||||
|
@ -56,6 +90,18 @@ def mixin_with_embeddings():
|
||||||
return OpenAIMixinWithEmbeddingsImpl()
|
return OpenAIMixinWithEmbeddingsImpl()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mixin_with_rerank():
|
||||||
|
"""Create a test instance of OpenAIMixin with rerank model list"""
|
||||||
|
return OpenAIMixinWithRerankImpl()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mixin_with_embeddings_and_rerank():
|
||||||
|
"""Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list"""
|
||||||
|
return OpenAIMixinWithEmbeddingsAndRerankImpl()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_models():
|
def mock_models():
|
||||||
"""Create multiple mock OpenAI model objects"""
|
"""Create multiple mock OpenAI model objects"""
|
||||||
|
@ -317,6 +363,96 @@ class TestOpenAIMixinEmbeddingModelMetadata:
|
||||||
assert llm_model.provider_resource_id == "gpt-4"
|
assert llm_model.provider_resource_id == "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIMixinRerankModelList:
|
||||||
|
"""Test cases for rerank_model_list attribute functionality"""
|
||||||
|
|
||||||
|
async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context):
|
||||||
|
"""Test that models in rerank_model_list are correctly identified as rerank models"""
|
||||||
|
# Create mock models: 1 rerank model and 1 LLM
|
||||||
|
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||||
|
mock_llm_model = MagicMock(id="gpt-4")
|
||||||
|
mock_models = [mock_rerank_model, mock_llm_model]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
|
||||||
|
async def mock_models_list():
|
||||||
|
for model in mock_models:
|
||||||
|
yield model
|
||||||
|
|
||||||
|
mock_client.models.list.return_value = mock_models_list()
|
||||||
|
|
||||||
|
with mock_client_context(mixin_with_rerank, mock_client):
|
||||||
|
result = await mixin_with_rerank.list_models()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
# Find the models in the result
|
||||||
|
rerank_model = next(m for m in result if m.identifier == "rerank-model-1")
|
||||||
|
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||||
|
|
||||||
|
# Check rerank model
|
||||||
|
assert rerank_model.model_type == ModelType.rerank
|
||||||
|
assert rerank_model.metadata == {} # No metadata for rerank models
|
||||||
|
assert rerank_model.provider_id == "test-provider"
|
||||||
|
assert rerank_model.provider_resource_id == "rerank-model-1"
|
||||||
|
|
||||||
|
# Check LLM model
|
||||||
|
assert llm_model.model_type == ModelType.llm
|
||||||
|
assert llm_model.metadata == {} # No metadata for LLMs
|
||||||
|
assert llm_model.provider_id == "test-provider"
|
||||||
|
assert llm_model.provider_resource_id == "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIMixinMixedModelTypes:
|
||||||
|
"""Test cases for mixed model types (LLM, embedding, rerank)"""
|
||||||
|
|
||||||
|
async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context):
|
||||||
|
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
|
||||||
|
# Create mock models: 1 embedding, 1 rerank, 1 LLM
|
||||||
|
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||||
|
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||||
|
mock_llm_model = MagicMock(id="gpt-4")
|
||||||
|
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
|
||||||
|
async def mock_models_list():
|
||||||
|
for model in mock_models:
|
||||||
|
yield model
|
||||||
|
|
||||||
|
mock_client.models.list.return_value = mock_models_list()
|
||||||
|
|
||||||
|
with mock_client_context(mixin_with_embeddings_and_rerank, mock_client):
|
||||||
|
result = await mixin_with_embeddings_and_rerank.list_models()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
# Find the models in the result
|
||||||
|
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||||
|
rerank_model = next(m for m in result if m.identifier == "rerank-model-1")
|
||||||
|
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||||
|
|
||||||
|
# Check embedding model
|
||||||
|
assert embedding_model.model_type == ModelType.embedding
|
||||||
|
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||||
|
assert embedding_model.provider_id == "test-provider"
|
||||||
|
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||||
|
|
||||||
|
# Check rerank model
|
||||||
|
assert rerank_model.model_type == ModelType.rerank
|
||||||
|
assert rerank_model.metadata == {} # No metadata for rerank models
|
||||||
|
assert rerank_model.provider_id == "test-provider"
|
||||||
|
assert rerank_model.provider_resource_id == "rerank-model-1"
|
||||||
|
|
||||||
|
# Check LLM model
|
||||||
|
assert llm_model.model_type == ModelType.llm
|
||||||
|
assert llm_model.metadata == {} # No metadata for LLMs
|
||||||
|
assert llm_model.provider_id == "test-provider"
|
||||||
|
assert llm_model.provider_resource_id == "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIMixinAllowedModels:
|
class TestOpenAIMixinAllowedModels:
|
||||||
"""Test cases for allowed_models filtering functionality"""
|
"""Test cases for allowed_models filtering functionality"""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue