Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
refactor: use extra_body to pass in input_type params for asymmetric embedding models for NVIDIA Inference Provider (#3804)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 4s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
Test Llama Stack Build / build (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 9s
API Conformance Tests / check-schema-compatibility (push) Successful in 16s
UI Tests / ui-tests (22) (push) Successful in 33s
Pre-commit / pre-commit (push) Successful in 1m33s
# What does this PR do?

Previously, the NVIDIA inference provider implemented a custom `openai_embeddings` method with a hardcoded `input_type="query"` parameter, which NVIDIA asymmetric embedding models require (https://github.com/llamastack/llama-stack/pull/3205). An `extra_body` parameter was recently added to the embeddings API (https://github.com/llamastack/llama-stack/pull/3794). This PR therefore updates the NVIDIA inference provider to use the base `OpenAIMixin.openai_embeddings` method and to pass `input_type` through `extra_body` for asymmetric embedding models.

## Test Plan

Run the following command for each of these values of `embedding_model`: `nvidia/llama-3.2-nv-embedqa-1b-v2`, `nvidia/nv-embedqa-e5-v5`, `nvidia/nv-embedqa-mistral-7b-v2`, and `snowflake/arctic-embed-l`.

```
pytest -s -v tests/integration/inference/test_openai_embeddings.py \
  --stack-config="inference=nvidia" \
  --embedding-model={embedding_model} \
  --env NVIDIA_API_KEY={nvidia_api_key} \
  --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com" \
  --inference-mode=record
```
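For reference, a minimal sketch of the call pattern this change enables (not code from this PR; the base URL, API key, and endpoint path are illustrative placeholders for a locally running Llama Stack server):

```python
# Hedged sketch, not part of this PR: any OpenAI-compatible client works,
# since input_type rides along in extra_body rather than as a first-class
# OpenAI parameter.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.embeddings.create(
    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
    input=["What is the capital of France?"],
    extra_body={"input_type": "query"},  # NVIDIA-specific field
)
print(len(response.data[0].embedding))
```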
This commit is contained in:

parent 866c13cdc2
commit d875e427bf

3 changed files with 75 additions and 70 deletions
NVIDIA provider documentation:

````diff
@@ -139,16 +139,13 @@ print(f"Structured Response: {structured_response.choices[0].message.content}")
 The following example shows how to create embeddings for an NVIDIA NIM.
 
-> [!NOTE]
-> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
-
 ```python
-response = client.inference.embeddings(
-    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
-    contents=["What is the capital of France?"],
-    task_type="query",
+response = client.embeddings.create(
+    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+    input=["What is the capital of France?"],
+    extra_body={"input_type": "query"},
 )
-print(f"Embeddings: {response.embeddings}")
+print(f"Embeddings: {response.data}")
 ```
 
 ### Vision Language Models Example
````
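The removed doc note pointed users to the legacy `embeddings` API with `task_type="document"` for passage embeddings. With this change the same `extra_body` channel should cover that case too; a hedged sketch (the value `"passage"` follows NVIDIA NIM's asymmetric-embedding convention and is an assumption here, not something stated in this diff):

```python
# Hedged sketch: document-side embeddings via the same extra_body channel.
# "passage" is assumed per NVIDIA NIM convention; verify for your model/NIM version.
response = client.embeddings.create(
    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
    input=["Paris is the capital of France."],
    extra_body={"input_type": "passage"},
)
```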
NVIDIA inference adapter (`NVIDIAInferenceAdapter`):

```diff
@@ -5,14 +5,6 @@
 # the root directory of this source tree.
 
-from openai import NOT_GIVEN
-
-from llama_stack.apis.inference import (
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsRequestWithExtraBody,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-)
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
```
```diff
@@ -76,50 +68,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         :return: The NVIDIA API base URL
         """
         return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
-
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        """
-        OpenAI-compatible embeddings for NVIDIA NIM.
-
-        Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
-        We default this to "query" to ensure requests succeed when using the
-        OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
-        `task_type='document'`.
-        """
-        extra_body: dict[str, object] = {"input_type": "query"}
-        logger.warning(
-            "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
-            "For passage embeddings, use the embeddings API with task_type='document'."
-        )
-
-        response = await self.client.embeddings.create(
-            model=await self._get_provider_model_id(params.model),
-            input=params.input,
-            encoding_format=params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
-            dimensions=params.dimensions if params.dimensions is not None else NOT_GIVEN,
-            user=params.user if params.user is not None else NOT_GIVEN,
-            extra_body=extra_body,
-        )
-
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=response.model,
-            usage=usage,
-        )
```
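Stripped of its response repackaging, the deleted override did one thing: force `input_type="query"` into the body of the OpenAI-compatible request. A hedged sketch of that core behavior using the `openai` SDK directly (`embed_query` is a hypothetical helper, not llama-stack code):

```python
from openai import AsyncOpenAI


async def embed_query(client: AsyncOpenAI, model: str, texts: list[str]):
    # extra_body merges NVIDIA's non-standard input_type field into the JSON
    # request sent to the OpenAI-compatible embeddings endpoint.
    return await client.embeddings.create(
        model=model,
        input=texts,
        extra_body={"input_type": "query"},
    )
```

With `extra_body` now plumbed through the embeddings API (PR #3794), the base `OpenAIMixin.openai_embeddings` can carry a caller-supplied `input_type` instead of a hardcoded `"query"`, which is what makes this override deletable.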
`tests/integration/inference/test_openai_embeddings.py`:

```diff
@@ -12,6 +12,15 @@ from openai import OpenAI
 
 from llama_stack.core.library_client import LlamaStackAsLibraryClient
 
+ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER = {
+    "remote::nvidia": [
+        "nvidia/llama-3.2-nv-embedqa-1b-v2",
+        "nvidia/nv-embedqa-e5-v5",
+        "nvidia/nv-embedqa-mistral-7b-v2",
+        "snowflake/arctic-embed-l",
+    ],
+}
+
 
 def decode_base64_to_floats(base64_string: str) -> list[float]:
     """Helper function to decode base64 string to list of float32 values."""
```
```diff
@@ -29,6 +38,28 @@ def provider_from_model(client_with_models, model_id):
     return providers[provider_id]
 
 
+def is_asymmetric_model(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    provider_type = provider.provider_type
+
+    if provider_type not in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER:
+        return False
+
+    return model_id in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER[provider_type]
+
+
+def get_extra_body_for_model(client_with_models, model_id, input_type="query"):
+    if not is_asymmetric_model(client_with_models, model_id):
+        return None
+
+    provider = provider_from_model(client_with_models, model_id)
+
+    if provider.provider_type == "remote::nvidia":
+        return {"input_type": input_type}
+
+    return None
+
+
 def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
```
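A hedged usage sketch of the helpers added above, assuming the model resolves to the `remote::nvidia` provider:

```python
# For an NVIDIA asymmetric model, the helper returns the extra body to attach.
extra_body = get_extra_body_for_model(client_with_models, "nvidia/nv-embedqa-e5-v5")
assert extra_body == {"input_type": "query"}

# For any other provider/model it returns None; the openai client accepts
# extra_body=None, so test call sites can pass the result unconditionally.
```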
```diff
@@ -40,17 +71,29 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
 
 def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
-    if provider.provider_type in (
+    should_skip = provider.provider_type in (
         "remote::databricks",  # param silently ignored, always returns floats
         "remote::fireworks",  # param silently ignored, always returns list of floats
         "remote::ollama",  # param silently ignored, always returns list of floats
-    ):
+    ) or (
+        provider.provider_type == "remote::nvidia"
+        and model_id
+        in [
+            "nvidia/nv-embedqa-e5-v5",
+            "nvidia/nv-embedqa-mistral-7b-v2",
+            "snowflake/arctic-embed-l",
+        ]
+    )
+
+    if should_skip:
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
 
 
 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if (
+    should_skip = (
         provider.provider_type
         in (
             "remote::together",  # returns 400
```
```diff
@@ -59,11 +102,19 @@ def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_i
             "remote::databricks",
             "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
         )
-    ):
-        pytest.skip(
-            f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
+        or (provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id)
+        or (
+            provider.provider_type == "remote::nvidia"
+            and model_id
+            in [
+                "nvidia/nv-embedqa-e5-v5",
+                "nvidia/nv-embedqa-mistral-7b-v2",
+                "snowflake/arctic-embed-l",
+            ]
         )
-    if provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id:
+    )
+
+    if should_skip:
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
         )
```
```diff
@@ -105,6 +156,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe
         model=embedding_model_id,
         input=input_text,
         encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     assert response.object == "list"
@@ -129,6 +181,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e
         model=embedding_model_id,
         input=input_texts,
         encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     assert response.object == "list"
@@ -155,6 +208,7 @@ def test_openai_embeddings_with_encoding_format_float(compat_client, client_with
         model=embedding_model_id,
         input=input_text,
         encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     assert response.object == "list"
@@ -175,6 +229,7 @@ def test_openai_embeddings_with_dimensions(compat_client, client_with_models, em
         model=embedding_model_id,
         input=input_text,
         dimensions=dimensions,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     assert response.object == "list"
@@ -196,6 +251,7 @@ def test_openai_embeddings_with_user_parameter(compat_client, client_with_models
         model=embedding_model_id,
         input=input_text,
         user=user_id,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     assert response.object == "list"
@@ -212,6 +268,7 @@ def test_openai_embeddings_empty_list_error(compat_client, client_with_models, e
         compat_client.embeddings.create(
             model=embedding_model_id,
             input=[],
+            extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
         )
@@ -223,6 +280,7 @@ def test_openai_embeddings_invalid_model_error(compat_client, client_with_models
         compat_client.embeddings.create(
             model="invalid-model-id",
             input="Test text",
+            extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
         )
@@ -233,16 +291,19 @@ def test_openai_embeddings_different_inputs_different_outputs(compat_client, cli
     input_text1 = "This is the first text"
     input_text2 = "This is completely different content"
 
+    extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
     response1 = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text1,
         encoding_format="float",
+        extra_body=extra_body,
     )
 
     response2 = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text2,
         encoding_format="float",
+        extra_body=extra_body,
     )
 
     embedding1 = response1.data[0].embedding
@@ -267,6 +328,7 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_wit
         input=input_text,
         encoding_format="base64",
         dimensions=dimensions,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
 
     # Validate response structure
@@ -298,6 +360,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
         model=embedding_model_id,
         input=input_texts,
         encoding_format="base64",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
     )
     # Validate response structure
     assert response.object == "list"
```