Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00
Fix rerank integration test based on client-side changes
parent bb2eb33fc3
commit 6b4940806f
8 changed files with 27 additions and 276 deletions
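
The client-side change being adapted to here: llama-stack-client moved the rerank surface under its alpha namespace, so both the InferenceRerankResponse import and the rerank() call now go through .alpha. A minimal before/after sketch, assuming a locally running server; the base URL and model ID are placeholders, not taken from this diff:

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.types.alpha import InferenceRerankResponse

    client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL

    # Old call path (before the client-side change):
    #   response = client.inference.rerank(model="my-rerank-model", query="q", items=["a", "b"])

    # New call path, matching the updated tests below:
    response = client.alpha.inference.rerank(model="my-rerank-model", query="q", items=["a", "b"])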
@@ -6,7 +6,7 @@
 import pytest
 
 from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import InferenceRerankResponse
+from llama_stack_client.types.alpha import InferenceRerankResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
     ImageContentItemImage,
@@ -97,7 +97,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
 def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, list)
     # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
     assert len(response) <= len(items)
@@ -129,9 +129,9 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
             ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
 
         assert isinstance(response, list)
         assert len(response) <= len(items)
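
The test_rerank_image hunk keeps the existing error split: run in-process via LlamaStackAsLibraryClient, an unsupported image item surfaces as a bare ValueError, while the HTTP client raises BadRequestError. A sketch of that selection pattern; the helper name is hypothetical and the LlamaStackAsLibraryClient import path is an assumption:

    from llama_stack_client import BadRequestError as LlamaStackBadRequestError
    from llama_stack.core.library_client import LlamaStackAsLibraryClient  # assumed import path

    def expected_rerank_error(client):
        # Hypothetical helper mirroring the test's branching: library-mode
        # clients raise ValueError directly; HTTP clients wrap the failure
        # in a BadRequestError.
        return ValueError if isinstance(client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError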
@@ -144,7 +144,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
 
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -160,7 +160,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
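
Both max-results tests pass a cap alongside the query and items; judging from the max_num_results local in the hunks above, the keyword argument carries the same name (an inference, since the diff truncates the argument list). A sketch, continuing the client from the first example:

    response = client.alpha.inference.rerank(
        model="my-rerank-model",  # placeholder model ID
        query="which item matches?",
        items=["item one", "item two", "item three"],
        max_num_results=2,  # return at most 2 ranked items
    )
    assert len(response) <= 2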
@@ -208,7 +208,7 @@ def test_rerank_semantic_correctness(
 ):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
 
     _validate_rerank_response(response, items)
     _validate_semantic_ranking(response, items, expected_first_item)
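
As the TODO in test_rerank_text notes, per-item type validation is deferred until InferenceRerankResponseItem is exported from the client, so the tests only assert the list contract: a list no longer than the input items. A sketch of consuming the result under that same contract; the per-item field names are assumptions, not confirmed by this diff:

    response = client.alpha.inference.rerank(model="my-rerank-model", query="q", items=["a", "b"])
    assert isinstance(response, list)
    assert len(response) <= 2
    for item in response:
        # index / relevance_score are assumed field names; read them defensively
        # until InferenceRerankResponseItem is exported and can be checked directly.
        print(getattr(item, "index", None), getattr(item, "relevance_score", None))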