Update tests

Jiayi 2025-09-09 16:21:00 -07:00
parent d78e30fe8b
commit f66718be80
2 changed files with 19 additions and 12 deletions


@@ -67,11 +67,11 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
         "mixed-content-2",
     ],
 )
-def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
-    response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, RerankResponse)
     assert len(response.data) <= len(items)
     _validate_rerank_response(response, items)
@@ -94,32 +94,32 @@ def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
         "mixed-content-2",
     ],
 )
-def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
-            ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
         assert isinstance(response, RerankResponse)
         assert len(response.data) <= len(items)
         _validate_rerank_response(response, items)


-def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -131,12 +131,12 @@ def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
     _validate_rerank_response(response, items)


-def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
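
The change is a pure fixture swap: every rerank test now takes client_with_models instead of llama_stack_client, and the request calls are unchanged. For context, a minimal sketch of what such a fixture could look like in conftest.py — only the fixture name comes from this diff; the body below is an assumption, not the repository's actual implementation:

    # conftest.py -- hypothetical sketch; the real fixture may differ.
    import pytest


    @pytest.fixture(scope="session")
    def client_with_models(llama_stack_client):
        # Assumption: wraps the base client after ensuring the models the
        # test suite relies on (e.g. the rerank models) are registered and
        # ready. Registration details are omitted since they are not shown
        # in this diff.
        return llama_stack_client

With a fixture like this, the tests above pick up the model setup automatically through pytest's dependency injection, without each test registering models itself.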