mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
Merge branch 'main' into feat/gunicorn-production-server
This commit is contained in:
commit
47bd994824
59 changed files with 3190 additions and 421 deletions
|
|
@ -171,6 +171,10 @@ def pytest_addoption(parser):
|
|||
"--embedding-model",
|
||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--rerank-model",
|
||||
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||
|
|
@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
|
|||
"shield_id": ("--safety-shield", "shield"),
|
||||
"judge_model_id": ("--judge-model", "judge"),
|
||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||
}
|
||||
|
||||
# Collect all parameters and their values
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ def client_with_models(
|
|||
vision_model_id,
|
||||
embedding_model_id,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
|
@ -170,6 +171,9 @@ def client_with_models(
|
|||
|
||||
if embedding_model_id and embedding_model_id not in model_ids:
|
||||
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
|
||||
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
|
||||
return client
|
||||
|
||||
|
||||
|
|
@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = [
|
||||
"text_model_id",
|
||||
"vision_model_id",
|
||||
"embedding_model_id",
|
||||
"judge_model_id",
|
||||
"shield_id",
|
||||
"rerank_model_id",
|
||||
]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
@ -230,6 +241,7 @@ def instantiate_llama_stack_client(session):
|
|||
|
||||
force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1"
|
||||
if force_restart:
|
||||
print(f"Forcing restart of the server on port {port}")
|
||||
stop_server_on_port(port)
|
||||
|
||||
# Check if port is available
|
||||
|
|
|
|||
|
|
@ -721,6 +721,6 @@ def test_openai_chat_completion_structured_output(openai_client, text_model_id,
|
|||
print(response.choices[0].message.content)
|
||||
answer = AnswerFormat.model_validate_json(response.choices[0].message.content)
|
||||
expected = tc["expected"]
|
||||
assert answer.first_name == expected["first_name"]
|
||||
assert answer.last_name == expected["last_name"]
|
||||
assert expected["first_name"].lower() in answer.first_name.lower()
|
||||
assert expected["last_name"].lower() in answer.last_name.lower()
|
||||
assert answer.year_of_birth == expected["year_of_birth"]
|
||||
|
|
|
|||
214
tests/integration/inference/test_rerank.py
Normal file
214
tests/integration/inference/test_rerank.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types.alpha import InferenceRerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
# Test data
|
||||
DUMMY_STRING = "string_1"
|
||||
DUMMY_STRING2 = "string_2"
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
|
||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
|
||||
supported_providers = {"remote::nvidia"}
|
||||
if inference_provider_type not in supported_providers:
|
||||
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
|
||||
|
||||
|
||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||
last_score = d.relevance_score
|
||||
|
||||
|
||||
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||
"""
|
||||
Validate that the expected most relevant item ranks first.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
expected_first_item: The expected first item in the ranking
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
if not response:
|
||||
raise AssertionError("No ranking data returned in response")
|
||||
|
||||
actual_first_index = response[0].index
|
||||
actual_first_item = items[actual_first_index]
|
||||
assert actual_first_item == expected_first_item, (
|
||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||
],
|
||||
ids=[
|
||||
"string-query-string-items",
|
||||
"text-query-text-items",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, list)
|
||||
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
],
|
||||
ids=[
|
||||
"image-query-url",
|
||||
"image-query-base64",
|
||||
"text-query-image-item",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||
error_type = (
|
||||
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
else:
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||
max_num_results = 2
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items) # Should return at most len(items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items,expected_first_item",
|
||||
[
|
||||
(
|
||||
"What is a reranking model? ",
|
||||
[
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
"Machine learning algorithms learn patterns from data. ",
|
||||
"Python is a programming language. ",
|
||||
],
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
),
|
||||
(
|
||||
"What is C++?",
|
||||
[
|
||||
"Learning new things is interesting. ",
|
||||
"C++ is a programming language. ",
|
||||
"Books provide knowledge and entertainment. ",
|
||||
],
|
||||
"C++ is a programming language. ",
|
||||
),
|
||||
(
|
||||
"What are good learning habits? ",
|
||||
[
|
||||
"Cooking pasta is a fun activity. ",
|
||||
"Plants need water and sunlight. ",
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
],
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rerank_semantic_correctness(
|
||||
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
||||
):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
_validate_rerank_response(response, items)
|
||||
_validate_semantic_ranking(response, items, expected_first_item)
|
||||
|
|
@ -4,18 +4,75 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
class TestInspect:
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
health = llama_stack_client.inspect.health()
|
||||
assert health is not None
|
||||
assert health.status == "OK"
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
version = llama_stack_client.inspect.version()
|
||||
assert version is not None
|
||||
assert version.version is not None
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_default(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with default filter (non-deprecated v1 routes)."""
|
||||
response = llama_stack_client.routes.list()
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
assert len(routes) > 0
|
||||
|
||||
# All routes should be non-deprecated
|
||||
# Check that we don't see any /openai/ routes (which are deprecated)
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) == 0, "Default filter should not include deprecated /openai/ routes"
|
||||
|
||||
# Should see standard v1 routes like /inspect/routes, /health, /version
|
||||
paths = [r.route for r in routes]
|
||||
assert "/inspect/routes" in paths or "/v1/inspect/routes" in paths
|
||||
assert "/health" in paths or "/v1/health" in paths
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_filter_by_deprecated(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with deprecated filter."""
|
||||
response = llama_stack_client.routes.list(api_filter="deprecated")
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
|
||||
# When filtering for deprecated, we should get deprecated routes
|
||||
# At minimum, we should see some /openai/ routes which are deprecated
|
||||
if len(routes) > 0:
|
||||
# If there are any deprecated routes, they should include openai routes
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) > 0, "Deprecated filter should include /openai/ routes"
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_filter_by_v1(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with v1 filter."""
|
||||
response = llama_stack_client.routes.list(api_filter="v1")
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
assert len(routes) > 0
|
||||
|
||||
# Should not include deprecated routes
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) == 0
|
||||
|
||||
# Should include v1 routes
|
||||
paths = [r.route for r in routes]
|
||||
assert any(
|
||||
"/v1/" in p or p.startswith("/inspect/") or p.startswith("/health") or p.startswith("/version")
|
||||
for p in paths
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import os
|
|||
|
||||
import pytest
|
||||
|
||||
import llama_stack.core.telemetry.telemetry as telemetry_module
|
||||
from llama_stack.testing.api_recorder import patch_httpx_for_test_id
|
||||
from tests.integration.fixtures.common import instantiate_llama_stack_client
|
||||
from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector
|
||||
|
|
@ -22,40 +21,26 @@ def telemetry_test_collector():
|
|||
stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
||||
if stack_mode == "server":
|
||||
# In server mode, the collector must be started and the server is already running.
|
||||
# The integration test script (scripts/integration-tests.sh) should have set
|
||||
# LLAMA_STACK_TEST_COLLECTOR_PORT and OTEL_EXPORTER_OTLP_ENDPOINT before starting the server.
|
||||
try:
|
||||
collector = OtlpHttpTestCollector()
|
||||
except RuntimeError as exc:
|
||||
pytest.skip(str(exc))
|
||||
env_overrides = {
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint,
|
||||
"OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf",
|
||||
"OTEL_BSP_SCHEDULE_DELAY": "200",
|
||||
"OTEL_BSP_EXPORT_TIMEOUT": "2000",
|
||||
"LLAMA_STACK_DISABLE_GUNICORN": "true", # Disable multi-process for telemetry collection
|
||||
}
|
||||
|
||||
previous_env = {key: os.environ.get(key) for key in env_overrides}
|
||||
previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART")
|
||||
|
||||
for key, value in env_overrides.items():
|
||||
os.environ[key] = value
|
||||
|
||||
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1"
|
||||
telemetry_module._TRACER_PROVIDER = None
|
||||
# Verify the collector is listening on the expected endpoint
|
||||
expected_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||
if expected_endpoint and collector.endpoint != expected_endpoint:
|
||||
pytest.skip(
|
||||
f"Collector endpoint mismatch: expected {expected_endpoint}, got {collector.endpoint}. "
|
||||
"Server was likely started before collector."
|
||||
)
|
||||
|
||||
try:
|
||||
yield collector
|
||||
finally:
|
||||
collector.shutdown()
|
||||
for key, prior in previous_env.items():
|
||||
if prior is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = prior
|
||||
if previous_force_restart is None:
|
||||
os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None)
|
||||
else:
|
||||
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart
|
||||
else:
|
||||
manager = InMemoryTelemetryManager()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -206,3 +206,65 @@ def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
|
|||
def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int):
|
||||
result = parse_and_maybe_upgrade_config(config_with_image_name_int)
|
||||
assert isinstance(result.image_name, str)
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_sets_external_providers_dir(up_to_date_config):
|
||||
"""Test that external_providers_dir is None when not specified (deprecated field)."""
|
||||
# Ensure the config doesn't have external_providers_dir set
|
||||
assert "external_providers_dir" not in up_to_date_config
|
||||
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
|
||||
# Verify external_providers_dir is None (not set to default)
|
||||
# This aligns with the deprecation of external_providers_dir
|
||||
assert result.external_providers_dir is None
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(up_to_date_config):
|
||||
"""Test that custom external_providers_dir values are preserved."""
|
||||
custom_dir = "/custom/providers/dir"
|
||||
up_to_date_config["external_providers_dir"] = custom_dir
|
||||
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
|
||||
# Verify the custom value was preserved
|
||||
assert str(result.external_providers_dir) == custom_dir
|
||||
|
||||
|
||||
def test_generate_run_config_from_providers():
|
||||
"""Test that _generate_run_config_from_providers creates a valid config"""
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.stack.run import StackRun
|
||||
from llama_stack.core.datatypes import Provider
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
stack_run = StackRun(subparsers)
|
||||
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_type="inline::meta-reference",
|
||||
provider_id="meta-reference",
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
config = stack_run._generate_run_config_from_providers(providers=providers)
|
||||
config_dict = config.model_dump(mode="json")
|
||||
|
||||
# Verify basic structure
|
||||
assert config_dict["image_name"] == "providers-run"
|
||||
assert "inference" in config_dict["apis"]
|
||||
assert "inference" in config_dict["providers"]
|
||||
|
||||
# Verify storage has all required stores including prompts
|
||||
assert "storage" in config_dict
|
||||
stores = config_dict["storage"]["stores"]
|
||||
assert "prompts" in stores
|
||||
assert stores["prompts"]["namespace"] == "prompts"
|
||||
|
||||
# Verify config can be parsed back
|
||||
parsed = parse_and_maybe_upgrade_config(config_dict)
|
||||
assert parsed.image_name == "providers-run"
|
||||
|
|
|
|||
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"rankings": []}
|
||||
self._text_data = text_data
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.post_calls = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
|
||||
class PostContext:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.response
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter.config.rerank_model_to_url = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
async def test_rerank_basic_functionality():
|
||||
adapter = create_adapter()
|
||||
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.5
|
||||
|
||||
url, kwargs = mock_session.post_calls[0]
|
||||
payload = kwargs["json"]
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["query"] == {"text": "test query"}
|
||||
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||
|
||||
|
||||
async def test_missing_rankings_key():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(json_data={}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert url == "https://model.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "http://localhost:8000" in url
|
||||
assert "model.endpoint/rerank" not in url
|
||||
|
||||
|
||||
async def test_max_num_results():
|
||||
adapter = create_adapter()
|
||||
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.8
|
||||
|
||||
|
||||
async def test_http_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_client_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_includes_configured_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
|
||||
dynamic_ids = ["llm-1", "embedding-1"]
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for m in rerank_models:
|
||||
assert m.provider_id == "nvidia"
|
||||
assert m.model_type == ModelType.rerank
|
||||
assert m.metadata == {}
|
||||
assert m.identifier in adapter._model_cache
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_has_no_duplicates():
|
||||
adapter = create_adapter()
|
||||
|
||||
dynamic_ids = [
|
||||
"llm-1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids
|
||||
"embedding-1",
|
||||
"llm-1",
|
||||
]
|
||||
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
assert len(ids) == len(set(ids))
|
||||
assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
|
||||
assert "nv-rerank-qa-mistral-4b:1" in ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
|
||||
adapter = create_adapter()
|
||||
|
||||
# Simulate dynamic listing failure
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
# Should still return configured rerank ids
|
||||
configured_ids = list(adapter.config.rerank_model_to_url.keys())
|
||||
assert set(ids) == set(configured_ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue