Update tests

Jiayi 2025-09-09 16:21:00 -07:00
parent d78e30fe8b
commit f66718be80
2 changed files with 19 additions and 12 deletions

@@ -153,10 +153,17 @@ def client_with_models(
             metadata={"embedding_dimension": embedding_dimension or 384},
         )
     if rerank_model_id and rerank_model_id not in model_ids:
-        rerank_provider = providers[0]
+        selected_provider = None
+        for p in providers:
+            # Currently only NVIDIA inference provider supports reranking
+            if p.provider_type == "remote::nvidia":
+                selected_provider = p
+                break
+        selected_provider = selected_provider or providers[0]
         client.models.register(
             model_id=rerank_model_id,
-            provider_id=rerank_provider.provider_id,
+            provider_id=selected_provider.provider_id,
             model_type="rerank",
         )
     return client
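
The hunk above replaces a blind `providers[0]` pick with a preference scan: take the first provider whose type is `remote::nvidia` (per the comment, currently the only inference provider with rerank support), and fall back to the first provider otherwise. Standalone, the pattern reduces to the sketch below; the `Provider` dataclass and the sample provider list are hypothetical stand-ins for the Llama Stack objects, not the project's actual types.

from dataclasses import dataclass


@dataclass
class Provider:
    # Hypothetical stand-in; only the two fields the fixture actually reads.
    provider_id: str
    provider_type: str


def pick_rerank_provider(providers: list[Provider]) -> Provider:
    # Prefer the provider type known to support reranking; otherwise
    # fall back to the first provider so registration still proceeds.
    for p in providers:
        if p.provider_type == "remote::nvidia":
            return p
    return providers[0]


providers = [Provider("ollama", "remote::ollama"), Provider("nvidia", "remote::nvidia")]
assert pick_rerank_provider(providers).provider_id == "nvidia"
assert pick_rerank_provider(providers[:1]).provider_id == "ollama"  # fallback path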

@ -67,11 +67,11 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
"mixed-content-2", "mixed-content-2",
], ],
) )
def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type): def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS: if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ") pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items) response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, RerankResponse) assert isinstance(response, RerankResponse)
assert len(response.data) <= len(items) assert len(response.data) <= len(items)
_validate_rerank_response(response, items) _validate_rerank_response(response, items)
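
The fixture swap from `llama_stack_client` to `client_with_models` is what makes the first hunk matter: `client_with_models` registers the rerank model (preferring the NVIDIA provider) before the test body runs, so `rerank_model_id` is guaranteed to resolve. A minimal sketch of that wiring, assuming a raw-client fixture named `llama_stack_client` and eliding the text/embedding registration the real fixture also performs:

import pytest


@pytest.fixture
def client_with_models(llama_stack_client, rerank_model_id):
    # Sketch only: the actual fixture (excerpted in the first hunk) also
    # registers text and embedding models and scans providers dynamically.
    if rerank_model_id:
        llama_stack_client.models.register(
            model_id=rerank_model_id,
            provider_id="nvidia",  # assumed id; the real fixture picks it by provider_type
            model_type="rerank",
        )
    return llama_stack_client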
@@ -94,32 +94,32 @@ def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
         "mixed-content-2",
     ],
 )
-def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
-            ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
         assert isinstance(response, RerankResponse)
         assert len(response.data) <= len(items)
         _validate_rerank_response(response, items)


-def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -131,12 +131,12 @@ def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
     _validate_rerank_response(response, items)


-def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
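
Every test above funnels into `_validate_rerank_response`, which appears in the hunk headers but is untouched by this commit. A hedged sketch of the invariants such a helper would plausibly check, assuming each entry of `response.data` carries an `index` into the input list and a `relevance_score` (typical rerank API fields, not confirmed by this diff):

def _validate_rerank_response(response, items) -> None:
    # Sketch of plausible invariants; the real helper is not shown in this commit.
    seen = set()
    for d in response.data:
        assert d.index not in seen, "no duplicate item indices"
        seen.add(d.index)
        assert 0 <= d.index < len(items), "index must point into the input items"
    scores = [d.relevance_score for d in response.data]
    assert scores == sorted(scores, reverse=True), "results ranked best-first"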