Add tests

Jiayi 2025-09-04 18:08:35 -07:00
parent bab9d7aaea
commit d7cbeb4b8c
6 changed files with 345 additions and 5 deletions


@@ -1,9 +1,9 @@
 ---
 description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
-  This API provides the raw interface to the underlying models. Two kinds of models are supported:
+  This API provides the raw interface to the underlying models. Three kinds of models are supported:
   - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
-  - Embedding models: these models generate embeddings to be used for semantic search."
+  - Embedding models: these models generate embeddings to be used for semantic search.
+  - Rerank models: these models rerank the documents by relevance."
 sidebar_label: Inference
 title: Inference
@@ -15,7 +15,7 @@ title: Inference
 Llama Stack Inference API for generating completions, chat completions, and embeddings.
-This API provides the raw interface to the underlying models. Two kinds of models are supported:
+This API provides the raw interface to the underlying models. Three kinds of models are supported:
 - LLM models: these models generate "raw" and "chat" (conversational) completions.
 - Embedding models: these models generate embeddings to be used for semantic search.
+- Rerank models: these models rerank the documents by relevance.
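For orientation, the rerank endpoint this description now advertises is the one exercised by the integration tests added later in this commit. A minimal usage sketch, assuming a locally running stack on the default port and a hypothetical rerank model id (the call shape mirrors llama_stack_client.inference.rerank as used in the new tests):

# Illustrative sketch only; not part of this commit. Model id and documents are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumes a local server on the default port
response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",  # hypothetical rerank model id
    query="Which passage mentions llamas?",
    items=["Llamas are camelids.", "Rust is a systems programming language."],
)
for entry in response.data:
    # entries come back sorted by descending relevance_score (see _validate_rerank_response below)
    print(entry.index, entry.relevance_score)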


@@ -151,7 +151,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
         for ranking in rankings:
             rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))

-        # Apply max_num_results limit if specified
+        # Apply max_num_results limit
         if max_num_results is not None:
             rerank_data = rerank_data[:max_num_results]
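For readers without the full adapter in view: rankings here is the parsed body of NVIDIA's ranking endpoint response, and this block truncates the converted results. A rough sketch of the surrounding transformation, with the response shape taken from the unit tests at the end of this commit (not a verbatim copy of the adapter):

# Sketch only; field names follow the mocked NVIDIA response used in the unit tests below.
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]  # parsed "rankings" array
rerank_data = [RerankData(index=r["index"], relevance_score=r["logit"]) for r in rankings]
if max_num_results is not None:
    rerank_data = rerank_data[:max_num_results]  # keep only the top max_num_results entries
# the adapter then returns these entries as the rerank response data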


@@ -120,6 +120,10 @@ def pytest_addoption(parser):
         "--embedding-model",
         help="comma-separated list of embedding models. Fixture name: embedding_model_id",
     )
+    parser.addoption(
+        "--rerank-model",
+        help="comma-separated list of rerank models. Fixture name: rerank_model_id",
+    )
     parser.addoption(
         "--safety-shield",
         help="comma-separated list of safety shields. Fixture name: shield_id",
@@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
         "shield_id": ("--safety-shield", "shield"),
         "judge_model_id": ("--judge-model", "judge"),
         "embedding_dimension": ("--embedding-dimension", "dim"),
+        "rerank_model_id": ("--rerank-model", "rerank"),
     }

     # Collect all parameters and their values
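The new entry follows the existing pattern: the comma-separated value of --rerank-model is turned into a rerank_model_id fixture by pytest_generate_tests. A hypothetical test consuming it (not part of this commit) would look roughly like:

# Hypothetical illustration of the fixture wiring.
# Running pytest with --rerank-model=model-a,model-b parametrizes this test once per model id;
# if the option is omitted, the autouse skip_if_no_model fixture (updated below) is responsible
# for skipping it.
def test_needs_rerank_model(rerank_model_id):
    assert isinstance(rerank_model_id, str)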


@@ -119,6 +119,7 @@ def client_with_models(
     embedding_model_id,
     embedding_dimension,
     judge_model_id,
+    rerank_model_id,
 ):
     client = llama_stack_client
@@ -151,6 +152,13 @@ def client_with_models(
             model_type="embedding",
             metadata={"embedding_dimension": embedding_dimension or 384},
         )
+    if rerank_model_id and rerank_model_id not in model_ids:
+        rerank_provider = providers[0]
+        client.models.register(
+            model_id=rerank_model_id,
+            provider_id=rerank_provider.provider_id,
+            model_type="rerank",
+        )

     return client
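With this change, a test that requests both client_with_models and rerank_model_id gets a client whose rerank model is already registered against the first available provider. A hedged sketch of how such a test could rely on that (hypothetical test; the real ones are in the new rerank test file below):

# Hypothetical sketch; the call shape matches the integration tests added in this commit.
def test_rerank_smoke(client_with_models, rerank_model_id):
    response = client_with_models.inference.rerank(
        model=rerank_model_id,
        query="example query",
        items=["first document", "second document"],
    )
    assert len(response.data) <= 2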
@@ -166,7 +174,7 @@ def model_providers(llama_stack_client):
 @pytest.fixture(autouse=True)
 def skip_if_no_model(request):
-    model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
+    model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"]
     test_func = request.node.function
     actual_params = inspect.signature(test_func).parameters.keys()


@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types import RerankResponse
from llama_stack_client.types.shared.interleaved_content import (
    ImageContentItem,
    ImageContentItemImage,
    ImageContentItemImageURL,
    TextContentItem,
)

from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
    image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
PROVIDERS_SUPPORTING_MEDIA = {}  # Providers that support media input for rerank models


def _validate_rerank_response(response: RerankResponse, items: list) -> None:
    """
    Validate that a rerank response has the correct structure and ordering.

    Args:
        response: The RerankResponse to validate
        items: The original items list that was ranked

    Raises:
        AssertionError: If any validation fails
    """
    seen = set()
    last_score = float("inf")
    for d in response.data:
        assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
        assert d.index not in seen, f"Duplicate index {d.index} found"
        seen.add(d.index)
        assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
        assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
        last_score = d.relevance_score


@pytest.mark.parametrize(
    "query,items",
    [
        (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
        (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
        (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
    ],
    ids=[
        "string-query-string-items",
        "text-query-text-items",
        "mixed-content-1",
        "mixed-content-2",
    ],
)
def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
    if inference_provider_type not in SUPPORTED_PROVIDERS:
        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
    response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
    assert isinstance(response, RerankResponse)
    assert len(response.data) <= len(items)
    _validate_rerank_response(response, items)


@pytest.mark.parametrize(
    "query,items",
    [
        (DUMMY_IMAGE_URL, [DUMMY_STRING]),
        (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_IMAGE_URL]),
        (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
    ],
    ids=[
        "image-query-url",
        "image-query-base64",
        "text-query-image-item",
        "mixed-content-1",
        "mixed-content-2",
    ],
)
def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
    if inference_provider_type not in SUPPORTED_PROVIDERS:
        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
    if inference_provider_type not in PROVIDERS_SUPPORTING_MEDIA:
        error_type = (
            ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
        )
        with pytest.raises(error_type):
            llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
    else:
        response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
        assert isinstance(response, RerankResponse)
        assert len(response.data) <= len(items)
        _validate_rerank_response(response, items)


def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
    if inference_provider_type not in SUPPORTED_PROVIDERS:
        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
    items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
    max_num_results = 2
    response = llama_stack_client.inference.rerank(
        model=rerank_model_id,
        query=DUMMY_STRING,
        items=items,
        max_num_results=max_num_results,
    )
    assert isinstance(response, RerankResponse)
    assert len(response.data) == max_num_results
    _validate_rerank_response(response, items)


def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
    if inference_provider_type not in SUPPORTED_PROVIDERS:
        pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
    items = [DUMMY_STRING, DUMMY_STRING2]
    response = llama_stack_client.inference.rerank(
        model=rerank_model_id,
        query=DUMMY_STRING,
        items=items,
        max_num_results=10,  # Larger than items length
    )
    assert isinstance(response, RerankResponse)
    assert len(response.data) <= len(items)  # Should return at most len(items)


@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, patch

import aiohttp
import pytest

from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter


class MockResponse:
    def __init__(self, status=200, json_data=None, text_data="OK"):
        self.status = status
        self._json_data = json_data or {"rankings": []}
        self._text_data = text_data

    async def json(self):
        return self._json_data

    async def text(self):
        return self._text_data


class MockSession:
    def __init__(self, response):
        self.response = response
        self.post_calls = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return False

    def post(self, url, **kwargs):
        self.post_calls.append((url, kwargs))

        class PostContext:
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return False

        return PostContext(self.response)


def create_adapter(config=None, model_metadata=None):
    if config is None:
        config = NVIDIAConfig(api_key="test-key")
    adapter = NVIDIAInferenceAdapter(config)

    class MockModel:
        provider_resource_id = "test-model"
        metadata = model_metadata or {}

    adapter.model_store = AsyncMock()
    adapter.model_store.get_model = AsyncMock(return_value=MockModel())
    return adapter


@pytest.mark.asyncio
async def test_rerank_basic_functionality():
    adapter = create_adapter()
    mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
    mock_session = MockSession(mock_response)
    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])

    assert len(result.data) == 1
    assert result.data[0].index == 0
    assert result.data[0].relevance_score == 0.5

    url, kwargs = mock_session.post_calls[0]
    payload = kwargs["json"]
    assert payload["model"] == "test-model"
    assert payload["query"] == {"text": "test query"}
    assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]


@pytest.mark.asyncio
async def test_missing_rankings_key():
    adapter = create_adapter()
    mock_session = MockSession(MockResponse(json_data={}))
    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="q", items=["a"])

    assert len(result.data) == 0


@pytest.mark.asyncio
async def test_hosted_with_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
    )
    mock_session = MockSession(MockResponse())
    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert url == "https://model.endpoint/rerank"


@pytest.mark.asyncio
async def test_hosted_without_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(api_key="key"),  # This creates hosted config (integrate.api.nvidia.com).
        model_metadata={},  # No "endpoint" key
    )
    mock_session = MockSession(MockResponse())
    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert "https://integrate.api.nvidia.com" in url


@pytest.mark.asyncio
async def test_self_hosted_ignores_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
        model_metadata={"endpoint": "https://model.endpoint/rerank"},  # This should be ignored.
    )
    mock_session = MockSession(MockResponse())
    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert "http://localhost:8000" in url
    assert "model.endpoint/rerank" not in url


@pytest.mark.asyncio
async def test_max_num_results():
    adapter = create_adapter()
    rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
    mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)

    assert len(result.data) == 1
    assert result.data[0].index == 0
    assert result.data[0].relevance_score == 0.8


@pytest.mark.asyncio
async def test_http_error():
    adapter = create_adapter()
    mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
    with patch("aiohttp.ClientSession", return_value=mock_session):
        with pytest.raises(ConnectionError, match="status 500.*Server Error"):
            await adapter.rerank(model="test-model", query="q", items=["a"])


@pytest.mark.asyncio
async def test_client_error():
    adapter = create_adapter()
    mock_session = AsyncMock()
    mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
    with patch("aiohttp.ClientSession", return_value=mock_session):
        with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
            await adapter.rerank(model="test-model", query="q", items=["a"])