mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Add tests
This commit is contained in:
parent
bab9d7aaea
commit
d7cbeb4b8c
6 changed files with 345 additions and 5 deletions
|
@ -120,6 +120,10 @@ def pytest_addoption(parser):
|
|||
"--embedding-model",
|
||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--rerank-model",
|
||||
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||
|
@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
|
|||
"shield_id": ("--safety-shield", "shield"),
|
||||
"judge_model_id": ("--judge-model", "judge"),
|
||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||
}
|
||||
|
||||
# Collect all parameters and their values
|
||||
|
|
|
@ -119,6 +119,7 @@ def client_with_models(
|
|||
embedding_model_id,
|
||||
embedding_dimension,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
@ -151,6 +152,13 @@ def client_with_models(
|
|||
model_type="embedding",
|
||||
metadata={"embedding_dimension": embedding_dimension or 384},
|
||||
)
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
rerank_provider = providers[0]
|
||||
client.models.register(
|
||||
model_id=rerank_model_id,
|
||||
provider_id=rerank_provider.provider_id,
|
||||
model_type="rerank",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
|
@ -166,7 +174,7 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
147
tests/integration/inference/test_rerank.py
Normal file
147
tests/integration/inference/test_rerank.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types import RerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
# Test data
|
||||
DUMMY_STRING = "string_1"
|
||||
DUMMY_STRING2 = "string_2"
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
|
||||
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||
|
||||
|
||||
def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The RerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response.data:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||
last_score = d.relevance_score
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||
],
|
||||
ids=[
|
||||
"string-query-string-items",
|
||||
"text-query-text-items",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||
|
||||
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
],
|
||||
ids=[
|
||||
"image-query-url",
|
||||
"image-query-base64",
|
||||
"text-query-image-item",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||
|
||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||
error_type = (
|
||||
ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
else:
|
||||
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||
max_num_results = 2
|
||||
|
||||
response = llama_stack_client.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||
response = llama_stack_client.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items) # Should return at most len(items)
|
Loading…
Add table
Add a link
Reference in a new issue