mirror of https://github.com/meta-llama/llama-stack.git
Update tests

commit f66718be80 (parent d78e30fe8b)
2 changed files with 19 additions and 12 deletions
@@ -153,10 +153,17 @@ def client_with_models(
             metadata={"embedding_dimension": embedding_dimension or 384},
         )
     if rerank_model_id and rerank_model_id not in model_ids:
-        rerank_provider = providers[0]
+        selected_provider = None
+        for p in providers:
+            # Currently only NVIDIA inference provider supports reranking
+            if p.provider_type == "remote::nvidia":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
         client.models.register(
             model_id=rerank_model_id,
-            provider_id=rerank_provider.provider_id,
+            provider_id=selected_provider.provider_id,
             model_type="rerank",
         )
     return client
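For reference, the provider preference introduced above can be written more compactly with next(); this is an illustrative sketch of the same logic, not part of the commit:

    # Pick the first provider whose type is "remote::nvidia" (currently the only
    # provider type that supports reranking); otherwise fall back to providers[0].
    selected_provider = next(
        (p for p in providers if p.provider_type == "remote::nvidia"),
        providers[0],
    )
    client.models.register(
        model_id=rerank_model_id,
        provider_id=selected_provider.provider_id,
        model_type="rerank",
    )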
@@ -67,11 +67,11 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
         "mixed-content-2",
     ],
 )
-def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")

-    response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, RerankResponse)
     assert len(response.data) <= len(items)
     _validate_rerank_response(response, items)
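Shown in isolation, the call pattern this test exercises looks like the sketch below; the query and items are illustrative placeholders rather than the parametrized test data:

    # Rerank a few candidate passages against a query and inspect the response.
    items = [
        "Mercury is the closest planet to the sun.",
        "Jupiter is the largest planet in the solar system.",
    ]
    response = client_with_models.inference.rerank(
        model=rerank_model_id,
        query="Which planet is closest to the sun?",
        items=items,
    )
    assert isinstance(response, RerankResponse)
    # The response never contains more entries than the number of input items.
    assert len(response.data) <= len(items)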
@@ -94,32 +94,32 @@ def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
         "mixed-content-2",
     ],
 )
-def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type):
+def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")

     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
-            ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)

         assert isinstance(response, RerankResponse)
         assert len(response.data) <= len(items)
         _validate_rerank_response(response, items)


-def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")

     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2

-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -131,12 +131,12 @@ def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type):
     _validate_rerank_response(response, items)


-def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type):
+def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")

     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = llama_stack_client.inference.rerank(
+    response = client_with_models.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
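Both max-results tests are truncated mid-call in this view. A sketch of the expected shape, assuming the cap is forwarded under the keyword max_num_results (the name is inferred from the local variable and is not confirmed by this diff):

    items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
    max_num_results = 2
    response = client_with_models.inference.rerank(
        model=rerank_model_id,
        query=DUMMY_STRING,
        items=items,
        max_num_results=max_num_results,  # assumed keyword name
    )
    # With a cap of 2, no more than 2 results should come back.
    assert len(response.data) <= max_num_results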