llama-stack-mirror/tests/unit/providers/nvidia/test_rerank_inference.py
2025-10-03 09:56:20 -07:00

222 lines
7.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
class MockResponse:
def __init__(self, status=200, json_data=None, text_data="OK"):
self.status = status
self._json_data = json_data or {"rankings": []}
self._text_data = text_data
async def json(self):
return self._json_data
async def text(self):
return self._text_data
class MockSession:
def __init__(self, response):
self.response = response
self.post_calls = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
def post(self, url, **kwargs):
self.post_calls.append((url, kwargs))
class PostContext:
def __init__(self, response):
self.response = response
async def __aenter__(self):
return self.response
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
return PostContext(self.response)
def create_adapter(config=None, rerank_endpoints=None):
if config is None:
config = NVIDIAConfig(api_key="test-key")
adapter = NVIDIAInferenceAdapter(config)
class MockModel:
provider_resource_id = "test-model"
metadata = {}
adapter.model_store = AsyncMock()
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
if rerank_endpoints is not None:
adapter._rerank_model_endpoints = rerank_endpoints
return adapter
async def test_rerank_basic_functionality():
adapter = create_adapter()
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
mock_session = MockSession(mock_response)
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.5
url, kwargs = mock_session.post_calls[0]
payload = kwargs["json"]
assert payload["model"] == "test-model"
assert payload["query"] == {"text": "test query"}
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
async def test_missing_rankings_key():
adapter = create_adapter()
mock_session = MockSession(MockResponse(json_data={}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a"])
assert len(result.data) == 0
async def test_hosted_with_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert url == "https://model.endpoint/rerank"
async def test_hosted_without_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
rerank_endpoints={}, # No endpoint mapping for test-model
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
async def test_hosted_model_not_in_endpoint_mapping():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
assert url != "https://other.endpoint/rerank"
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "http://localhost:8000" in url
assert "model.endpoint/rerank" not in url
async def test_max_num_results():
adapter = create_adapter()
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.8
async def test_http_error():
adapter = create_adapter()
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_client_error():
adapter = create_adapter()
mock_session = AsyncMock()
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_list_models_adds_rerank_models():
"""Test that list_models adds rerank models to the dynamic model list."""
adapter = create_adapter()
adapter.__provider_id__ = "nvidia"
# Mock the list_models from the superclass to return some dynamic models
base_models = [
MagicMock(identifier="llm-1", model_type=ModelType.llm),
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
]
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
result = await adapter.list_models()
assert result is not None
# Check that the rerank models are added
model_ids = [m.identifier for m in result]
assert "nv-rerank-qa-mistral-4b:1" in model_ids
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
assert len(rerank_models) == 3
for rerank_model in rerank_models:
assert rerank_model.provider_id == "nvidia"
assert rerank_model.metadata == {}
assert rerank_model.identifier in adapter._model_cache