Merge branch 'main' into feat/gunicorn-production-server

This commit is contained in:
Roy Belio 2025-11-02 16:13:15 +02:00 committed by GitHub
commit 47bd994824
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 3190 additions and 421 deletions

View file

@ -171,6 +171,10 @@ def pytest_addoption(parser):
"--embedding-model",
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
)
parser.addoption(
"--rerank-model",
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
)
parser.addoption(
"--safety-shield",
help="comma-separated list of safety shields. Fixture name: shield_id",
@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
"rerank_model_id": ("--rerank-model", "rerank"),
}
# Collect all parameters and their values

View file

@ -153,6 +153,7 @@ def client_with_models(
vision_model_id,
embedding_model_id,
judge_model_id,
rerank_model_id,
):
client = llama_stack_client
@ -170,6 +171,9 @@ def client_with_models(
if embedding_model_id and embedding_model_id not in model_ids:
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
if rerank_model_id and rerank_model_id not in model_ids:
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
return client
@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
@pytest.fixture(autouse=True)
def skip_if_no_model(request):
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
model_fixtures = [
"text_model_id",
"vision_model_id",
"embedding_model_id",
"judge_model_id",
"shield_id",
"rerank_model_id",
]
test_func = request.node.function
actual_params = inspect.signature(test_func).parameters.keys()
@ -230,6 +241,7 @@ def instantiate_llama_stack_client(session):
force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1"
if force_restart:
print(f"Forcing restart of the server on port {port}")
stop_server_on_port(port)
# Check if port is available

View file

@ -721,6 +721,6 @@ def test_openai_chat_completion_structured_output(openai_client, text_model_id,
print(response.choices[0].message.content)
answer = AnswerFormat.model_validate_json(response.choices[0].message.content)
expected = tc["expected"]
assert answer.first_name == expected["first_name"]
assert answer.last_name == expected["last_name"]
assert expected["first_name"].lower() in answer.first_name.lower()
assert expected["last_name"].lower() in answer.last_name.lower()
assert answer.year_of_birth == expected["year_of_birth"]

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types.alpha import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
ImageContentItemImageURL,
TextContentItem,
)
from llama_stack.core.library_client import LlamaStackAsLibraryClient
# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
supported_providers = {"remote::nvidia"}
if inference_provider_type not in supported_providers:
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
"""
Validate that a rerank response has the correct structure and ordering.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
Raises:
AssertionError: If any validation fails
"""
seen = set()
last_score = float("inf")
for d in response:
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
assert d.index not in seen, f"Duplicate index {d.index} found"
seen.add(d.index)
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
last_score = d.relevance_score
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
"""
Validate that the expected most relevant item ranks first.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
expected_first_item: The expected first item in the ranking
Raises:
AssertionError: If any validation fails
"""
if not response:
raise AssertionError("No ranking data returned in response")
actual_first_index = response[0].index
actual_first_item = items[actual_first_index]
assert actual_first_item == expected_first_item, (
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
],
ids=[
"string-query-string-items",
"text-query-text-items",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
assert len(response) <= len(items)
_validate_rerank_response(response, items)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
],
ids=[
"image-query-url",
"image-query-base64",
"text-query-image-item",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
error_type = (
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
)
with pytest.raises(error_type):
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
else:
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
assert len(response) <= len(items)
_validate_rerank_response(response, items)
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=max_num_results,
)
assert isinstance(response, list)
assert len(response) == max_num_results
_validate_rerank_response(response, items)
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2]
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=10, # Larger than items length
)
assert isinstance(response, list)
assert len(response) <= len(items) # Should return at most len(items)
@pytest.mark.parametrize(
"query,items,expected_first_item",
[
(
"What is a reranking model? ",
[
"A reranking model reranks a list of items based on the query. ",
"Machine learning algorithms learn patterns from data. ",
"Python is a programming language. ",
],
"A reranking model reranks a list of items based on the query. ",
),
(
"What is C++?",
[
"Learning new things is interesting. ",
"C++ is a programming language. ",
"Books provide knowledge and entertainment. ",
],
"C++ is a programming language. ",
),
(
"What are good learning habits? ",
[
"Cooking pasta is a fun activity. ",
"Plants need water and sunlight. ",
"Good learning habits include reading daily and taking notes. ",
],
"Good learning habits include reading daily and taking notes. ",
),
],
)
def test_rerank_semantic_correctness(
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
_validate_rerank_response(response, items)
_validate_semantic_ranking(response, items, expected_first_item)

View file

@ -4,18 +4,75 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient
class TestInspect:
@pytest.mark.skip(reason="inspect tests disabled")
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
health = llama_stack_client.inspect.health()
assert health is not None
assert health.status == "OK"
@pytest.mark.skip(reason="inspect tests disabled")
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
version = llama_stack_client.inspect.version()
assert version is not None
assert version.version is not None
@pytest.mark.skip(reason="inspect tests disabled")
def test_list_routes_default(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
"""Test list_routes with default filter (non-deprecated v1 routes)."""
response = llama_stack_client.routes.list()
assert response is not None
assert hasattr(response, "data")
routes = response.data
assert len(routes) > 0
# All routes should be non-deprecated
# Check that we don't see any /openai/ routes (which are deprecated)
openai_routes = [r for r in routes if "/openai/" in r.route]
assert len(openai_routes) == 0, "Default filter should not include deprecated /openai/ routes"
# Should see standard v1 routes like /inspect/routes, /health, /version
paths = [r.route for r in routes]
assert "/inspect/routes" in paths or "/v1/inspect/routes" in paths
assert "/health" in paths or "/v1/health" in paths
@pytest.mark.skip(reason="inspect tests disabled")
def test_list_routes_filter_by_deprecated(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
"""Test list_routes with deprecated filter."""
response = llama_stack_client.routes.list(api_filter="deprecated")
assert response is not None
assert hasattr(response, "data")
routes = response.data
# When filtering for deprecated, we should get deprecated routes
# At minimum, we should see some /openai/ routes which are deprecated
if len(routes) > 0:
# If there are any deprecated routes, they should include openai routes
openai_routes = [r for r in routes if "/openai/" in r.route]
assert len(openai_routes) > 0, "Deprecated filter should include /openai/ routes"
@pytest.mark.skip(reason="inspect tests disabled")
def test_list_routes_filter_by_v1(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
"""Test list_routes with v1 filter."""
response = llama_stack_client.routes.list(api_filter="v1")
assert response is not None
assert hasattr(response, "data")
routes = response.data
assert len(routes) > 0
# Should not include deprecated routes
openai_routes = [r for r in routes if "/openai/" in r.route]
assert len(openai_routes) == 0
# Should include v1 routes
paths = [r.route for r in routes]
assert any(
"/v1/" in p or p.startswith("/inspect/") or p.startswith("/health") or p.startswith("/version")
for p in paths
)

View file

@ -10,7 +10,6 @@ import os
import pytest
import llama_stack.core.telemetry.telemetry as telemetry_module
from llama_stack.testing.api_recorder import patch_httpx_for_test_id
from tests.integration.fixtures.common import instantiate_llama_stack_client
from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector
@ -22,40 +21,26 @@ def telemetry_test_collector():
stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_mode == "server":
# In server mode, the collector must be started and the server is already running.
# The integration test script (scripts/integration-tests.sh) should have set
# LLAMA_STACK_TEST_COLLECTOR_PORT and OTEL_EXPORTER_OTLP_ENDPOINT before starting the server.
try:
collector = OtlpHttpTestCollector()
except RuntimeError as exc:
pytest.skip(str(exc))
env_overrides = {
"OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint,
"OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf",
"OTEL_BSP_SCHEDULE_DELAY": "200",
"OTEL_BSP_EXPORT_TIMEOUT": "2000",
"LLAMA_STACK_DISABLE_GUNICORN": "true", # Disable multi-process for telemetry collection
}
previous_env = {key: os.environ.get(key) for key in env_overrides}
previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART")
for key, value in env_overrides.items():
os.environ[key] = value
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1"
telemetry_module._TRACER_PROVIDER = None
# Verify the collector is listening on the expected endpoint
expected_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
if expected_endpoint and collector.endpoint != expected_endpoint:
pytest.skip(
f"Collector endpoint mismatch: expected {expected_endpoint}, got {collector.endpoint}. "
"Server was likely started before collector."
)
try:
yield collector
finally:
collector.shutdown()
for key, prior in previous_env.items():
if prior is None:
os.environ.pop(key, None)
else:
os.environ[key] = prior
if previous_force_restart is None:
os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None)
else:
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart
else:
manager = InMemoryTelemetryManager()
try: