mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Add tests
This commit is contained in:
parent
bab9d7aaea
commit
d7cbeb4b8c
6 changed files with 345 additions and 5 deletions
|
@ -1,9 +1,9 @@
|
||||||
---
|
---
|
||||||
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||||
|
|
||||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
This API provides the raw interface to the underlying models. Three kinds of models are supported:
|
||||||
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
|
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
|
||||||
- Embedding models: these models generate embeddings to be used for semantic search."
|
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||||
- Rerank models: these models rerank the documents by relevance."
|
- Rerank models: these models rerank the documents by relevance."
|
||||||
sidebar_label: Inference
|
sidebar_label: Inference
|
||||||
title: Inference
|
title: Inference
|
||||||
|
@ -15,7 +15,7 @@ title: Inference
|
||||||
|
|
||||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||||
|
|
||||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
This API provides the raw interface to the underlying models. Three kinds of models are supported:
|
||||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||||
- Rerank models: these models rerank the documents by relevance.
|
- Rerank models: these models rerank the documents by relevance.
|
||||||
|
|
|
@ -151,7 +151,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||||
for ranking in rankings:
|
for ranking in rankings:
|
||||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||||
|
|
||||||
# Apply max_num_results limit if specified
|
# Apply max_num_results limit
|
||||||
if max_num_results is not None:
|
if max_num_results is not None:
|
||||||
rerank_data = rerank_data[:max_num_results]
|
rerank_data = rerank_data[:max_num_results]
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,10 @@ def pytest_addoption(parser):
|
||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||||
)
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--rerank-model",
|
||||||
|
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||||
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--safety-shield",
|
"--safety-shield",
|
||||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||||
|
@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
|
||||||
"shield_id": ("--safety-shield", "shield"),
|
"shield_id": ("--safety-shield", "shield"),
|
||||||
"judge_model_id": ("--judge-model", "judge"),
|
"judge_model_id": ("--judge-model", "judge"),
|
||||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||||
|
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Collect all parameters and their values
|
# Collect all parameters and their values
|
||||||
|
|
|
@ -119,6 +119,7 @@ def client_with_models(
|
||||||
embedding_model_id,
|
embedding_model_id,
|
||||||
embedding_dimension,
|
embedding_dimension,
|
||||||
judge_model_id,
|
judge_model_id,
|
||||||
|
rerank_model_id,
|
||||||
):
|
):
|
||||||
client = llama_stack_client
|
client = llama_stack_client
|
||||||
|
|
||||||
|
@ -151,6 +152,13 @@ def client_with_models(
|
||||||
model_type="embedding",
|
model_type="embedding",
|
||||||
metadata={"embedding_dimension": embedding_dimension or 384},
|
metadata={"embedding_dimension": embedding_dimension or 384},
|
||||||
)
|
)
|
||||||
|
if rerank_model_id and rerank_model_id not in model_ids:
|
||||||
|
rerank_provider = providers[0]
|
||||||
|
client.models.register(
|
||||||
|
model_id=rerank_model_id,
|
||||||
|
provider_id=rerank_provider.provider_id,
|
||||||
|
model_type="rerank",
|
||||||
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,7 +174,7 @@ def model_providers(llama_stack_client):
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def skip_if_no_model(request):
|
def skip_if_no_model(request):
|
||||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"]
|
||||||
test_func = request.node.function
|
test_func = request.node.function
|
||||||
|
|
||||||
actual_params = inspect.signature(test_func).parameters.keys()
|
actual_params = inspect.signature(test_func).parameters.keys()
|
||||||
|
|
147
tests/integration/inference/test_rerank.py
Normal file
147
tests/integration/inference/test_rerank.py
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||||
|
from llama_stack_client.types import RerankResponse
|
||||||
|
from llama_stack_client.types.shared.interleaved_content import (
|
||||||
|
ImageContentItem,
|
||||||
|
ImageContentItemImage,
|
||||||
|
ImageContentItemImageURL,
|
||||||
|
TextContentItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
DUMMY_STRING = "string_1"
|
||||||
|
DUMMY_STRING2 = "string_2"
|
||||||
|
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||||
|
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||||
|
DUMMY_IMAGE_URL = ImageContentItem(
|
||||||
|
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||||
|
)
|
||||||
|
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||||
|
|
||||||
|
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||||
|
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
||||||
|
"""
|
||||||
|
Validate that a rerank response has the correct structure and ordering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The RerankResponse to validate
|
||||||
|
items: The original items list that was ranked
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If any validation fails
|
||||||
|
"""
|
||||||
|
seen = set()
|
||||||
|
last_score = float("inf")
|
||||||
|
for d in response.data:
|
||||||
|
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||||
|
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||||
|
seen.add(d.index)
|
||||||
|
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||||
|
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||||
|
last_score = d.relevance_score
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,items",
|
||||||
|
[
|
||||||
|
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||||
|
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||||
|
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||||
|
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"string-query-string-items",
|
||||||
|
"text-query-text-items",
|
||||||
|
"mixed-content-1",
|
||||||
|
"mixed-content-2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
|
||||||
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
|
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||||
|
|
||||||
|
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert len(response.data) <= len(items)
|
||||||
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,items",
|
||||||
|
[
|
||||||
|
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||||
|
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||||
|
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||||
|
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||||
|
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"image-query-url",
|
||||||
|
"image-query-base64",
|
||||||
|
"text-query-image-item",
|
||||||
|
"mixed-content-1",
|
||||||
|
"mixed-content-2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
|
||||||
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
|
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||||
|
|
||||||
|
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||||
|
error_type = (
|
||||||
|
ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||||
|
)
|
||||||
|
with pytest.raises(error_type):
|
||||||
|
llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
|
else:
|
||||||
|
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert len(response.data) <= len(items)
|
||||||
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
|
||||||
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
|
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||||
|
|
||||||
|
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||||
|
max_num_results = 2
|
||||||
|
|
||||||
|
response = llama_stack_client.inference.rerank(
|
||||||
|
model=rerank_model_id,
|
||||||
|
query=DUMMY_STRING,
|
||||||
|
items=items,
|
||||||
|
max_num_results=max_num_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert len(response.data) == max_num_results
|
||||||
|
_validate_rerank_response(response, items)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
|
||||||
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
|
pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
|
||||||
|
|
||||||
|
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||||
|
response = llama_stack_client.inference.rerank(
|
||||||
|
model=rerank_model_id,
|
||||||
|
query=DUMMY_STRING,
|
||||||
|
items=items,
|
||||||
|
max_num_results=10, # Larger than items length
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert len(response.data) <= len(items) # Should return at most len(items)
|
180
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
180
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||||
|
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||||
|
self.status = status
|
||||||
|
self._json_data = json_data or {"rankings": []}
|
||||||
|
self._text_data = text_data
|
||||||
|
|
||||||
|
async def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
async def text(self):
|
||||||
|
return self._text_data
|
||||||
|
|
||||||
|
|
||||||
|
class MockSession:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
self.post_calls = []
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def post(self, url, **kwargs):
|
||||||
|
self.post_calls.append((url, kwargs))
|
||||||
|
|
||||||
|
class PostContext:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return PostContext(self.response)
|
||||||
|
|
||||||
|
|
||||||
|
def create_adapter(config=None, model_metadata=None):
|
||||||
|
if config is None:
|
||||||
|
config = NVIDIAConfig(api_key="test-key")
|
||||||
|
|
||||||
|
adapter = NVIDIAInferenceAdapter(config)
|
||||||
|
|
||||||
|
class MockModel:
|
||||||
|
provider_resource_id = "test-model"
|
||||||
|
metadata = model_metadata or {}
|
||||||
|
|
||||||
|
adapter.model_store = AsyncMock()
|
||||||
|
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||||
|
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rerank_basic_functionality():
|
||||||
|
adapter = create_adapter()
|
||||||
|
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||||
|
mock_session = MockSession(mock_response)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].index == 0
|
||||||
|
assert result.data[0].relevance_score == 0.5
|
||||||
|
|
||||||
|
url, kwargs = mock_session.post_calls[0]
|
||||||
|
payload = kwargs["json"]
|
||||||
|
assert payload["model"] == "test-model"
|
||||||
|
assert payload["query"] == {"text": "test query"}
|
||||||
|
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_rankings_key():
|
||||||
|
adapter = create_adapter()
|
||||||
|
mock_session = MockSession(MockResponse(json_data={}))
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
assert len(result.data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hosted_with_endpoint():
|
||||||
|
adapter = create_adapter(
|
||||||
|
config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
|
||||||
|
)
|
||||||
|
mock_session = MockSession(MockResponse())
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
url, _ = mock_session.post_calls[0]
|
||||||
|
assert url == "https://model.endpoint/rerank"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hosted_without_endpoint():
|
||||||
|
adapter = create_adapter(
|
||||||
|
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||||
|
model_metadata={}, # No "endpoint" key
|
||||||
|
)
|
||||||
|
mock_session = MockSession(MockResponse())
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
url, _ = mock_session.post_calls[0]
|
||||||
|
assert "https://integrate.api.nvidia.com" in url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_self_hosted_ignores_endpoint():
|
||||||
|
adapter = create_adapter(
|
||||||
|
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||||
|
model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored.
|
||||||
|
)
|
||||||
|
mock_session = MockSession(MockResponse())
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
url, _ = mock_session.post_calls[0]
|
||||||
|
assert "http://localhost:8000" in url
|
||||||
|
assert "model.endpoint/rerank" not in url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_num_results():
|
||||||
|
adapter = create_adapter()
|
||||||
|
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||||
|
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].index == 0
|
||||||
|
assert result.data[0].relevance_score == 0.8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_error():
|
||||||
|
adapter = create_adapter()
|
||||||
|
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||||
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_error():
|
||||||
|
adapter = create_adapter()
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||||
|
await adapter.rerank(model="test-model", query="q", items=["a"])
|
Loading…
Add table
Add a link
Reference in a new issue