test: cleanup embedding model test suite (#1322)

# What does this PR do?

 - skip media tests for models that do not support media
 - skip output_dimension tests for models that do not support it
 - skip task_type tests for models that do not support it
 - provide task_type for models that require it
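
The gating pattern is the same for each capability: a module-level set of supported model ids checked at the top of the test, plus a `default_task_type()` helper that supplies the extra keyword argument for models that require it. A minimal, self-contained sketch of that pattern (the `example/embedding-model` id, test name, and commented-out client call are hypothetical; the diff below shows the actual changes):

```python
import pytest

# Illustrative capability sets; the real suite lists specific NVIDIA/Snowflake model ids.
MODELS_SUPPORTING_MEDIA = set()
MODELS_REQUIRING_TASK_TYPE = {"example/embedding-model"}  # hypothetical id


def default_task_type(model_id):
    """Extra kwargs for models that require a task_type; empty dict otherwise."""
    if model_id in MODELS_REQUIRING_TASK_TYPE:
        return {"task_type": "query"}
    return {}


@pytest.mark.parametrize("embedding_model_id", ["example/embedding-model"])
def test_media_gating_sketch(embedding_model_id):
    # Gate on capability: record an expected failure instead of a hard error.
    if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
        pytest.xfail(f"{embedding_model_id} doesn't support media")
    # A real test would call the client here, forwarding the default task_type:
    #   client.inference.embeddings(
    #       model_id=embedding_model_id,
    #       contents=[...],
    #       **default_task_type(embedding_model_id),
    #   )
```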

## Test Plan

`LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model ...`
Matthew Farrellee 2025-02-28 12:02:36 -06:00, committed by GitHub
parent c91548fe07
commit 83dc8fbdff

@@ -76,6 +76,25 @@ DUMMY_IMAGE_URL = ImageContentItem(
 )
 DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
 SUPPORTED_PROVIDERS = {"remote::nvidia"}
+MODELS_SUPPORTING_MEDIA = {}
+MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
+MODELS_REQUIRING_TASK_TYPE = {
+    "nvidia/llama-3.2-nv-embedqa-1b-v2",
+    "nvidia/nv-embedqa-e5-v5",
+    "nvidia/nv-embedqa-mistral-7b-v2",
+    "snowflake/arctic-embed-l",
+}
+MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
+
+
+def default_task_type(model_id):
+    """
+    Some models require a task type parameter. This provides a default value for
+    testing those models.
+    """
+    if model_id in MODELS_REQUIRING_TASK_TYPE:
+        return {"task_type": "query"}
+    return {}


 @pytest.mark.parametrize(
@@ -92,7 +111,9 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
 def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+    )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
     assert isinstance(response.embeddings[0], list)
@@ -110,11 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
         "list[url,string,base64,text]",
     ],
 )
-@pytest.mark.xfail(reason="Media is not supported")
 def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
+        pytest.xfail(f"{embedding_model_id} doesn't support media")
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+    )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
     assert isinstance(response.embeddings[0], list)
@@ -145,7 +169,10 @@ def test_embedding_truncation(
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
+        model_id=embedding_model_id,
+        contents=contents,
+        text_truncation=text_truncation,
+        **default_task_type(embedding_model_id),
     )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == 1
@@ -178,26 +205,36 @@ def test_embedding_truncation_error(
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
-            model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
+            model_id=embedding_model_id,
+            contents=[DUMMY_LONG_TEXT],
+            text_truncation=text_truncation,
+            **default_task_type(embedding_model_id),
         )


-@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
 def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
+    if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
+        pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
+    base_response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
+    )
     test_response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
+        model_id=embedding_model_id,
+        contents=[DUMMY_STRING],
+        **default_task_type(embedding_model_id),
+        output_dimension=32,
     )
     assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
     assert len(test_response.embeddings[0]) == 32


-@pytest.mark.xfail(reason="Only valid for model supporting task type")
 def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
+    if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
+        pytest.xfail(f"{embedding_model_id} doesn't support task_type")
     query_embedding = llama_stack_client.inference.embeddings(
         model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
     )
@@ -220,7 +257,10 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+        model_id=embedding_model_id,
+        contents=[DUMMY_STRING],
+        text_truncation=text_truncation,
+        **default_task_type(embedding_model_id),
     )
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == 1
@@ -245,5 +285,8 @@ def test_embedding_text_truncation_error(
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
-            model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+            model_id=embedding_model_id,
+            contents=[DUMMY_STRING],
+            text_truncation=text_truncation,
+            **default_task_type(embedding_model_id),
         )