forked from phoenix-oss/llama-stack-mirror
test: cleanup embedding model test suite (#1322)
# What does this PR do? - skip media tests for models that do not support media - skip output_dimension tests for models that do not support it - skip task_type tests for models that do not support it - provide task_type for models that require it ## Test Plan `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model ...`
This commit is contained in:
parent
c91548fe07
commit
83dc8fbdff
1 changed files with 54 additions and 11 deletions
|
@ -76,6 +76,25 @@ DUMMY_IMAGE_URL = ImageContentItem(
|
||||||
)
|
)
|
||||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||||
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||||
|
MODELS_SUPPORTING_MEDIA = {}
|
||||||
|
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
|
||||||
|
MODELS_REQUIRING_TASK_TYPE = {
|
||||||
|
"nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||||
|
"nvidia/nv-embedqa-e5-v5",
|
||||||
|
"nvidia/nv-embedqa-mistral-7b-v2",
|
||||||
|
"snowflake/arctic-embed-l",
|
||||||
|
}
|
||||||
|
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
|
||||||
|
|
||||||
|
|
||||||
|
def default_task_type(model_id):
|
||||||
|
"""
|
||||||
|
Some models require a task type parameter. This provides a default value for
|
||||||
|
testing those models.
|
||||||
|
"""
|
||||||
|
if model_id in MODELS_REQUIRING_TASK_TYPE:
|
||||||
|
return {"task_type": "query"}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -92,7 +111,9 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||||
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
|
response = llama_stack_client.inference.embeddings(
|
||||||
|
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
|
||||||
|
)
|
||||||
assert isinstance(response, EmbeddingsResponse)
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
||||||
assert isinstance(response.embeddings[0], list)
|
assert isinstance(response.embeddings[0], list)
|
||||||
|
@ -110,11 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents, infere
|
||||||
"list[url,string,base64,text]",
|
"list[url,string,base64,text]",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.xfail(reason="Media is not supported")
|
|
||||||
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
|
if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
|
||||||
|
pytest.xfail(f"{embedding_model_id} doesn't support media")
|
||||||
|
response = llama_stack_client.inference.embeddings(
|
||||||
|
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
|
||||||
|
)
|
||||||
assert isinstance(response, EmbeddingsResponse)
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
||||||
assert isinstance(response.embeddings[0], list)
|
assert isinstance(response.embeddings[0], list)
|
||||||
|
@ -145,7 +169,10 @@ def test_embedding_truncation(
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
response = llama_stack_client.inference.embeddings(
|
response = llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
|
model_id=embedding_model_id,
|
||||||
|
contents=contents,
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
**default_task_type(embedding_model_id),
|
||||||
)
|
)
|
||||||
assert isinstance(response, EmbeddingsResponse)
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
assert len(response.embeddings) == 1
|
assert len(response.embeddings) == 1
|
||||||
|
@ -178,26 +205,36 @@ def test_embedding_truncation_error(
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
with pytest.raises(BadRequestError):
|
with pytest.raises(BadRequestError):
|
||||||
llama_stack_client.inference.embeddings(
|
llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
|
model_id=embedding_model_id,
|
||||||
|
contents=[DUMMY_LONG_TEXT],
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
**default_task_type(embedding_model_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
|
|
||||||
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
|
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
|
if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
|
||||||
|
pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
|
||||||
|
base_response = llama_stack_client.inference.embeddings(
|
||||||
|
model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
|
||||||
|
)
|
||||||
test_response = llama_stack_client.inference.embeddings(
|
test_response = llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
|
model_id=embedding_model_id,
|
||||||
|
contents=[DUMMY_STRING],
|
||||||
|
**default_task_type(embedding_model_id),
|
||||||
|
output_dimension=32,
|
||||||
)
|
)
|
||||||
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
|
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
|
||||||
assert len(test_response.embeddings[0]) == 32
|
assert len(test_response.embeddings[0]) == 32
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Only valid for model supporting task type")
|
|
||||||
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
|
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
|
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
|
||||||
|
pytest.xfail(f"{embedding_model_id} doesn't support task_type")
|
||||||
query_embedding = llama_stack_client.inference.embeddings(
|
query_embedding = llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
|
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
|
||||||
)
|
)
|
||||||
|
@ -220,7 +257,10 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
response = llama_stack_client.inference.embeddings(
|
response = llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
|
model_id=embedding_model_id,
|
||||||
|
contents=[DUMMY_STRING],
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
**default_task_type(embedding_model_id),
|
||||||
)
|
)
|
||||||
assert isinstance(response, EmbeddingsResponse)
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
assert len(response.embeddings) == 1
|
assert len(response.embeddings) == 1
|
||||||
|
@ -245,5 +285,8 @@ def test_embedding_text_truncation_error(
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||||
with pytest.raises(BadRequestError):
|
with pytest.raises(BadRequestError):
|
||||||
llama_stack_client.inference.embeddings(
|
llama_stack_client.inference.embeddings(
|
||||||
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
|
model_id=embedding_model_id,
|
||||||
|
contents=[DUMMY_STRING],
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
**default_task_type(embedding_model_id),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue