Add tests

This commit is contained in:
Jiayi 2025-09-04 18:08:35 -07:00
parent bab9d7aaea
commit d7cbeb4b8c
6 changed files with 345 additions and 5 deletions

View file

@ -1,9 +1,9 @@
---
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search."
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models rerank the documents by relevance."
sidebar_label: Inference
title: Inference
@ -15,7 +15,7 @@ title: Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models rerank the documents by relevance.

View file

@ -151,7 +151,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit if specified
# Apply max_num_results limit
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]

View file

@ -120,6 +120,10 @@ def pytest_addoption(parser):
"--embedding-model",
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
)
parser.addoption(
"--rerank-model",
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
)
parser.addoption(
"--safety-shield",
help="comma-separated list of safety shields. Fixture name: shield_id",
@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
"rerank_model_id": ("--rerank-model", "rerank"),
}
# Collect all parameters and their values

View file

@ -119,6 +119,7 @@ def client_with_models(
embedding_model_id,
embedding_dimension,
judge_model_id,
rerank_model_id,
):
client = llama_stack_client
@ -151,6 +152,13 @@ def client_with_models(
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension or 384},
)
if rerank_model_id and rerank_model_id not in model_ids:
rerank_provider = providers[0]
client.models.register(
model_id=rerank_model_id,
provider_id=rerank_provider.provider_id,
model_type="rerank",
)
return client
@ -166,7 +174,7 @@ def model_providers(llama_stack_client):
@pytest.fixture(autouse=True)
def skip_if_no_model(request):
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"]
test_func = request.node.function
actual_params = inspect.signature(test_func).parameters.keys()

View file

@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types import RerankResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
ImageContentItemImageURL,
TextContentItem,
)
from llama_stack.core.library_client import LlamaStackAsLibraryClient
# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
def _validate_rerank_response(response: RerankResponse, items: list) -> None:
"""
Validate that a rerank response has the correct structure and ordering.
Args:
response: The RerankResponse to validate
items: The original items list that was ranked
Raises:
AssertionError: If any validation fails
"""
seen = set()
last_score = float("inf")
for d in response.data:
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
assert d.index not in seen, f"Duplicate index {d.index} found"
seen.add(d.index)
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
last_score = d.relevance_score
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
],
ids=[
"string-query-string-items",
"text-query-text-items",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, RerankResponse)
assert len(response.data) <= len(items)
_validate_rerank_response(response, items)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
],
ids=[
"image-query-url",
"image-query-base64",
"text-query-image-item",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
error_type = (
ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
)
with pytest.raises(error_type):
llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
else:
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, RerankResponse)
assert len(response.data) <= len(items)
_validate_rerank_response(response, items)
def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
response = llama_stack_client.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=max_num_results,
)
assert isinstance(response, RerankResponse)
assert len(response.data) == max_num_results
_validate_rerank_response(response, items)
def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
items = [DUMMY_STRING, DUMMY_STRING2]
response = llama_stack_client.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=10, # Larger than items length
)
assert isinstance(response, RerankResponse)
assert len(response.data) <= len(items) # Should return at most len(items)

View file

@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import aiohttp
import pytest
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
class MockResponse:
def __init__(self, status=200, json_data=None, text_data="OK"):
self.status = status
self._json_data = json_data or {"rankings": []}
self._text_data = text_data
async def json(self):
return self._json_data
async def text(self):
return self._text_data
class MockSession:
def __init__(self, response):
self.response = response
self.post_calls = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
def post(self, url, **kwargs):
self.post_calls.append((url, kwargs))
class PostContext:
def __init__(self, response):
self.response = response
async def __aenter__(self):
return self.response
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
return PostContext(self.response)
def create_adapter(config=None, model_metadata=None):
if config is None:
config = NVIDIAConfig(api_key="test-key")
adapter = NVIDIAInferenceAdapter(config)
class MockModel:
provider_resource_id = "test-model"
metadata = model_metadata or {}
adapter.model_store = AsyncMock()
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
return adapter
@pytest.mark.asyncio
async def test_rerank_basic_functionality():
adapter = create_adapter()
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
mock_session = MockSession(mock_response)
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.5
url, kwargs = mock_session.post_calls[0]
payload = kwargs["json"]
assert payload["model"] == "test-model"
assert payload["query"] == {"text": "test query"}
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
@pytest.mark.asyncio
async def test_missing_rankings_key():
adapter = create_adapter()
mock_session = MockSession(MockResponse(json_data={}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a"])
assert len(result.data) == 0
@pytest.mark.asyncio
async def test_hosted_with_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert url == "https://model.endpoint/rerank"
@pytest.mark.asyncio
async def test_hosted_without_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
model_metadata={}, # No "endpoint" key
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
@pytest.mark.asyncio
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored.
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "http://localhost:8000" in url
assert "model.endpoint/rerank" not in url
@pytest.mark.asyncio
async def test_max_num_results():
adapter = create_adapter()
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.8
@pytest.mark.asyncio
async def test_http_error():
adapter = create_adapter()
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
@pytest.mark.asyncio
async def test_client_error():
adapter = create_adapter()
mock_session = AsyncMock()
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
await adapter.rerank(model="test-model", query="q", items=["a"])