Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 04:22:35 +00:00)
fix: use extra_body for passing input_type params for asymmetric embedding models for NVIDIA Inference Provider
commit 1d4d263d57 (parent 007efa6eb5)
3 changed files with 235 additions and 120 deletions
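The practical effect of this change: when the target is one of NVIDIA's asymmetric embedding models, the OpenAI-compatible embeddings call must carry input_type inside extra_body rather than as a top-level argument. A minimal client-side sketch, assuming a locally running Llama Stack server; the base URL, API key, and model name below are placeholders for illustration:

from openai import OpenAI

# Placeholder endpoint and model; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.embeddings.create(
    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
    input="What is the capital of France?",
    encoding_format="float",
    # Asymmetric models need to know whether the text is a query or a passage;
    # the OpenAI Python client forwards extra_body as additional JSON fields.
    extra_body={"input_type": "query"},
)
print(len(response.data[0].embedding))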
@@ -21,6 +21,16 @@ def decode_base64_to_floats(base64_string: str) -> list[float]:
     return list(embedding_floats)


+ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER = {
+    "remote::nvidia": [
+        "nvidia/llama-3.2-nv-embedqa-1b-v2",
+        "nvidia/nv-embedqa-e5-v5",
+        "nvidia/nv-embedqa-mistral-7b-v2",
+        "snowflake/arctic-embed-l",
+    ],
+}
+
+
 def provider_from_model(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
@@ -29,6 +39,25 @@ def provider_from_model(client_with_models, model_id):
     return providers[provider_id]


+def is_asymmetric_model(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    provider_type = provider.provider_type
+
+    if provider_type not in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER:
+        return False
+
+    return model_id in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER[provider_type]
+
+
+def get_extra_body_for_model(client_with_models, model_id, input_type="query"):
+    provider = provider_from_model(client_with_models, model_id)
+
+    if provider.provider_type == "remote::nvidia":
+        return {"input_type": input_type}
+
+    return None
+
+
 def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
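Taken together, these two helpers let every test below branch on whether the model under test is asymmetric and, if so, attach the provider-specific fields. The condensed pattern, as a sketch (compat_client, client_with_models, and embedding_model_id are the pytest fixtures used throughout this test module):

extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

kwargs = {
    "model": embedding_model_id,
    "input": "Hello, world!",
    "encoding_format": "float",
}
if is_asymmetric_model(client_with_models, embedding_model_id):
    # Asymmetric NVIDIA models reject requests that omit input_type, so the
    # tests pass it through extra_body instead of a top-level parameter.
    kwargs["extra_body"] = extra_body

response = compat_client.embeddings.create(**kwargs)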
@@ -40,17 +69,29 @@ def skip_if_model_doesnt_support_user_param(client, model_id):

 def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
-    if provider.provider_type in (
+    should_skip = provider.provider_type in (
         "remote::databricks",  # param silently ignored, always returns floats
         "remote::fireworks",  # param silently ignored, always returns list of floats
         "remote::ollama",  # param silently ignored, always returns list of floats
-    ):
+    ) or (
+        provider.provider_type == "remote::nvidia"
+        and model_id
+        in [
+            "nvidia/nv-embedqa-e5-v5",
+            "nvidia/nv-embedqa-mistral-7b-v2",
+            "snowflake/arctic-embed-l",
+        ]
+    )
+
+    if should_skip:
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")


 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if (
+    should_skip = (
         provider.provider_type
         in (
             "remote::together",  # returns 400
@@ -59,11 +100,19 @@ def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
             "remote::databricks",
             "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
         )
-    ):
-        pytest.skip(
-            f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
-        )
-    if provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id:
+        or (provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id)
+        or (
+            provider.provider_type == "remote::nvidia"
+            and model_id
+            in [
+                "nvidia/nv-embedqa-e5-v5",
+                "nvidia/nv-embedqa-mistral-7b-v2",
+                "snowflake/arctic-embed-l",
+            ]
+        )
+    )
+
+    if should_skip:
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
         )
@@ -100,12 +149,27 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embedding_model_id):
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text = "Hello, world!"
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text,
-        encoding_format="float",
-    )
+    # For asymmetric models, verify that calling without extra_body raises an error
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text,
+            "encoding_format": "float",
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_text,
+        "encoding_format": "float",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     assert response.object == "list"
@@ -124,12 +188,26 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, embedding_model_id):
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_texts = ["Hello, world!", "How are you today?", "This is a test."]
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_texts,
-        encoding_format="float",
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_texts,
+            "encoding_format": "float",
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     assert response.object == "list"
@@ -150,12 +228,26 @@ def test_openai_embeddings_with_encoding_format_float(compat_client, client_with_models, embedding_model_id):
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text = "Test encoding format"
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text,
-        encoding_format="float",
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text,
+            "encoding_format": "float",
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_text,
+        "encoding_format": "float",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     assert response.object == "list"
     assert len(response.data) == 1
@@ -170,12 +262,26 @@ def test_openai_embeddings_with_dimensions(compat_client, client_with_models, embedding_model_id):

     input_text = "Test dimensions parameter"
     dimensions = 16
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text,
-        dimensions=dimensions,
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text,
+            "dimensions": dimensions,
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_text,
+        "dimensions": dimensions,
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     assert response.object == "list"
     assert len(response.data) == 1
@@ -191,12 +297,26 @@ def test_openai_embeddings_with_user_parameter(compat_client, client_with_models, embedding_model_id):

     input_text = "Test user parameter"
     user_id = "test-user-123"
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text,
-        user=user_id,
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text,
+            "user": user_id,
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_text,
+        "user": user_id,
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     assert response.object == "list"
     assert len(response.data) == 1
@@ -208,11 +328,17 @@ def test_openai_embeddings_empty_list_error(compat_client, client_with_models, embedding_model_id):
     """Test that empty list input raises an appropriate error."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": [],
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
     with pytest.raises(Exception):  # noqa: B017
-        compat_client.embeddings.create(
-            model=embedding_model_id,
-            input=[],
-        )
+        compat_client.embeddings.create(**kwargs)


 def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id):
@@ -232,18 +358,35 @@ def test_openai_embeddings_different_inputs_different_outputs(compat_client, client_with_models, embedding_model_id):

     input_text1 = "This is the first text"
     input_text2 = "This is completely different content"
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response1 = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text1,
-        encoding_format="float",
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text1,
+            "encoding_format": "float",
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)

-    response2 = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text2,
-        encoding_format="float",
-    )
+    kwargs1 = {
+        "model": embedding_model_id,
+        "input": input_text1,
+        "encoding_format": "float",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs1["extra_body"] = extra_body
+
+    kwargs2 = {
+        "model": embedding_model_id,
+        "input": input_text2,
+        "encoding_format": "float",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs2["extra_body"] = extra_body
+
+    response1 = compat_client.embeddings.create(**kwargs1)
+    response2 = compat_client.embeddings.create(**kwargs2)

     embedding1 = response1.data[0].embedding
     embedding2 = response2.data[0].embedding
@@ -261,13 +404,28 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id):

     input_text = "Test base64 encoding format"
     dimensions = 12
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_text,
-        encoding_format="base64",
-        dimensions=dimensions,
-    )
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_text,
+            "encoding_format": "base64",
+            "dimensions": dimensions,
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_text,
+        "encoding_format": "base64",
+        "dimensions": dimensions,
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

     # Validate response structure
     assert response.object == "list"
@@ -293,12 +451,27 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
     skip_if_model_doesnt_support_encoding_format_base64(client_with_models, embedding_model_id)

     input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
+
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs_without_extra = {
+            "model": embedding_model_id,
+            "input": input_texts,
+            "encoding_format": "base64",
+        }
+        with pytest.raises(Exception):  # noqa: B017
+            compat_client.embeddings.create(**kwargs_without_extra)
+
+    kwargs = {
+        "model": embedding_model_id,
+        "input": input_texts,
+        "encoding_format": "base64",
+    }
+    if is_asymmetric_model(client_with_models, embedding_model_id):
+        kwargs["extra_body"] = extra_body
+
+    response = compat_client.embeddings.create(**kwargs)

-    response = compat_client.embeddings.create(
-        model=embedding_model_id,
-        input=input_texts,
-        encoding_format="base64",
-    )
     # Validate response structure
     assert response.object == "list"
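The query/passage split is the reason input_type exists: with asymmetric models, documents and queries are embedded differently, so each side of a retrieval pipeline passes its own value. A hedged sketch of the two call sites, using the same placeholder client setup and model name as the first example, and assuming "passage" as the document-side value accepted by these NVIDIA models:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # placeholder endpoint
model = "nvidia/llama-3.2-nv-embedqa-1b-v2"  # placeholder asymmetric model

# Index time: embed stored documents as passages.
doc = client.embeddings.create(
    model=model,
    input="Paris is the capital of France.",
    extra_body={"input_type": "passage"},
)

# Query time: embed the user question as a query.
query = client.embeddings.create(
    model=model,
    input="What is the capital of France?",
    extra_body={"input_type": "query"},
)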