diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py
index 69d35d05d..075f927f7 100644
--- a/tests/client-sdk/inference/test_embedding.py
+++ b/tests/client-sdk/inference/test_embedding.py
@@ -76,6 +76,25 @@ DUMMY_IMAGE_URL = ImageContentItem(
 )
 DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
 SUPPORTED_PROVIDERS = {"remote::nvidia"}
+MODELS_SUPPORTING_MEDIA = set()
+MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
+MODELS_REQUIRING_TASK_TYPE = {
+    "nvidia/llama-3.2-nv-embedqa-1b-v2",
+    "nvidia/nv-embedqa-e5-v5",
+    "nvidia/nv-embedqa-mistral-7b-v2",
+    "snowflake/arctic-embed-l",
+}
+MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
+
+
+def default_task_type(model_id):
+    """
+    Some models require a task type parameter. This provides a default value for
+    testing those models.
+    """
+    if model_id in MODELS_REQUIRING_TASK_TYPE:
+        return {"task_type": "query"}
+    return {}
 
 
 @pytest.mark.parametrize(
@@ -92,7 +111,9 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
 def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+    )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
     assert isinstance(response.embeddings[0], list)
@@ -110,11 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents, infere
         "list[url,string,base64,text]",
     ],
 )
-@pytest.mark.xfail(reason="Media is not supported")
 def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
+        pytest.xfail(f"{embedding_model_id} doesn't support media")
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+    )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
     assert isinstance(response.embeddings[0], list)
@@ -145,7 +169,10 @@ def test_embedding_truncation(
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
+        model_id=embedding_model_id,
+        contents=contents,
+        text_truncation=text_truncation,
+        **default_task_type(embedding_model_id),
     )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == 1
@@ -178,26 +205,36 @@ def test_embedding_truncation_error(
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
-            model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
+            model_id=embedding_model_id,
+            contents=[DUMMY_LONG_TEXT],
+            text_truncation=text_truncation,
+            **default_task_type(embedding_model_id),
         )
 
 
-@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
 def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
+    if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
+        pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
+    base_response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
+    )
     test_response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
+        model_id=embedding_model_id,
+        contents=[DUMMY_STRING],
+        **default_task_type(embedding_model_id),
+        output_dimension=32,
     )
     assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
     assert len(test_response.embeddings[0]) == 32
 
 
-@pytest.mark.xfail(reason="Only valid for model supporting task type")
 def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
+    if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
+        pytest.xfail(f"{embedding_model_id} doesn't support task_type")
     query_embedding = llama_stack_client.inference.embeddings(
         model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
     )
@@ -220,7 +257,10 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+        model_id=embedding_model_id,
+        contents=[DUMMY_STRING],
+        text_truncation=text_truncation,
+        **default_task_type(embedding_model_id),
     )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == 1
@@ -245,5 +285,8 @@ def test_embedding_text_truncation_error(
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
-            model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+            model_id=embedding_model_id,
+            contents=[DUMMY_STRING],
+            text_truncation=text_truncation,
+            **default_task_type(embedding_model_id),
         )
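
Note: the **default_task_type(embedding_model_id) unpacking above merges an optional keyword argument into each client call without branching at every call site. Below is a minimal, self-contained sketch of that pattern, runnable without a Llama Stack deployment; the embeddings stub and the all-MiniLM-L6-v2 model ID are illustrative stand-ins, not part of this change.

# Sketch of the per-model default-kwargs pattern used in the diff (stdlib only).

MODELS_REQUIRING_TASK_TYPE = {
    "nvidia/llama-3.2-nv-embedqa-1b-v2",
    "nvidia/nv-embedqa-e5-v5",
    "nvidia/nv-embedqa-mistral-7b-v2",
    "snowflake/arctic-embed-l",
}


def default_task_type(model_id):
    """Return the extra kwargs a model needs, or an empty dict for all others."""
    if model_id in MODELS_REQUIRING_TASK_TYPE:
        return {"task_type": "query"}
    return {}


def embeddings(model_id, contents, **kwargs):
    # Stand-in for llama_stack_client.inference.embeddings: echoes its kwargs
    # so the effect of the dict unpacking is visible.
    return {"model_id": model_id, "num_contents": len(contents), **kwargs}


# A model in the set picks up task_type="query"; any other model gets no extra
# argument, so providers that reject unknown parameters are not affected.
print(embeddings("nvidia/nv-embedqa-e5-v5", ["hello"], **default_task_type("nvidia/nv-embedqa-e5-v5")))
print(embeddings("all-MiniLM-L6-v2", ["hello"], **default_task_type("all-MiniLM-L6-v2")))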