mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Add rerank models to the dynamic model list; Fix integration tests
This commit is contained in:
parent
3538477070
commit
816b68fdc7
8 changed files with 247 additions and 25 deletions
|
@ -4,11 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
|
@ -170,3 +171,35 @@ async def test_client_error():
|
|||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_adds_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
|
||||
# Mock the list_models from the superclass to return some dynamic models
|
||||
base_models = [
|
||||
MagicMock(identifier="llm-1", model_type=ModelType.llm),
|
||||
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
|
||||
]
|
||||
|
||||
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for rerank_model in rerank_models:
|
||||
assert rerank_model.provider_id == "nvidia"
|
||||
assert rerank_model.metadata == {}
|
||||
assert rerank_model.identifier in adapter._model_cache
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue