From f66718be8033b685abbe829753df539fc652e27d Mon Sep 17 00:00:00 2001
From: Jiayi
Date: Tue, 9 Sep 2025 16:21:00 -0700
Subject: [PATCH] Update tests

---
 tests/integration/fixtures/common.py       | 11 +++++++++--
 tests/integration/inference/test_rerank.py | 20 ++++++++++----------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 27283afe7..8f4c564c8 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -153,10 +153,17 @@ def client_with_models(
             metadata={"embedding_dimension": embedding_dimension or 384},
         )
     if rerank_model_id and rerank_model_id not in model_ids:
-        rerank_provider = providers[0]
+        selected_provider = None
+        for p in providers:
+            # Currently only NVIDIA inference provider supports reranking
+            if p.provider_type == "remote::nvidia":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
         client.models.register(
             model_id=rerank_model_id,
-            provider_id=rerank_provider.provider_id,
+            provider_id=selected_provider.provider_id,
             model_type="rerank",
         )
     return client
diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py
index 0c536b539..27f3074ad 100644
--- a/tests/integration/inference/test_rerank.py
+++ b/tests/integration/inference/test_rerank.py
@@ -67,11 +67,11 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
         "mixed-content-2",
     ],
 )
-def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
 
-    response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, RerankResponse)
     assert len(response.data) <= len(items)
     _validate_rerank_response(response, items)
@@ -94,32 +94,32 @@ def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inferenc
         "mixed-content-2",
     ],
 )
-def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(
             f"{inference_provider_type} doesn't support rerank models yet. "
         )
 
     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
-            ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
         assert isinstance(response, RerankResponse)
         assert len(response.data) <= len(items)
         _validate_rerank_response(response, items)
 
 
-def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
 
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -131,12 +131,12 @@ def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provi
     _validate_rerank_response(response, items)
 
 
-def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
 
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,