diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 9384928aa..29e11c39d 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -8,7 +8,7 @@ import logging
 import warnings
 from typing import AsyncIterator, List, Optional, Union
 
-from openai import APIConnectionError, AsyncOpenAI
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -170,11 +170,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
                 raise ValueError(f"Invalid task_type: {task_type}")
             extra_body["input_type"] = task_type_options[task_type]
 
-        response = await self._client.embeddings.create(
-            model=model,
-            input=input,
-            extra_body=extra_body,
-        )
+        try:
+            response = await self._client.embeddings.create(
+                model=model,
+                input=input,
+                extra_body=extra_body,
+            )
+        except BadRequestError as e:
+            raise ValueError(f"Failed to get embeddings: {e}") from e
 
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py
index 2d168b525..6b60c8d9b 100644
--- a/tests/client-sdk/inference/test_embedding.py
+++ b/tests/client-sdk/inference/test_embedding.py
@@ -55,6 +55,7 @@
 #
 
 import pytest
+from llama_stack_client import BadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
@@ -63,8 +64,6 @@ from llama_stack_client.types.shared.interleaved_content import (
     TextContentItem,
 )
 
-from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation
-
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
 DUMMY_LONG_TEXT = "NVDA " * 10240
@@ -120,8 +119,8 @@ def test_embedding_image(llama_stack_client, embedding_model_id, contents):
 @pytest.mark.parametrize(
     "text_truncation",
     [
-        TextTruncation.end,
-        TextTruncation.start,
+        "end",
+        "start",
     ],
 )
 @pytest.mark.parametrize(
@@ -149,7 +148,7 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
     "text_truncation",
     [
         None,
-        TextTruncation.none,
+        "none",
     ],
 )
 @pytest.mark.parametrize(
@@ -164,13 +163,10 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
     ],
 )
 def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents):
-    response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
-    )
-    assert isinstance(response, EmbeddingsResponse)
-    assert len(response.embeddings) == 1
-    assert isinstance(response.embeddings[0], list)
-    assert isinstance(response.embeddings[0][0], float)
+    with pytest.raises(BadRequestError) as excinfo:
+        llama_stack_client.inference.embeddings(
+            model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
+        )
 
 
 @pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
@@ -186,9 +182,9 @@ def test_embedding_output_dimension(llama_stack_client, embedding_model_id):
 @pytest.mark.xfail(reason="Only valid for model supporting task type")
 def test_embedding_task_type(llama_stack_client, embedding_model_id):
     query_embedding = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.query
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
     )
     document_embedding = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.document
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
    )
     assert query_embedding.embeddings != document_embedding.embeddings
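
For reviewers, a self-contained sketch (not part of the patch) of the error-translation pattern the adapter change introduces, written against the plain openai client; the base_url, api_key, and model id below are placeholders, not taken from the diff:

# Sketch: how the BadRequestError -> ValueError translation in the adapter
# behaves when exercised directly against an OpenAI-compatible embeddings
# endpoint. Endpoint, key, and model id are placeholders.
import asyncio

from openai import AsyncOpenAI, BadRequestError


async def embed(texts: list[str]) -> list[list[float]]:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
    try:
        response = await client.embeddings.create(
            model="my-embedding-model",  # placeholder
            input=texts,
            extra_body={"input_type": "query"},  # NVIDIA-specific knob, as in the adapter
        )
    except BadRequestError as e:
        # e.g. the input exceeds the model's context window with truncation
        # disabled; surface it as a ValueError, as the adapter now does.
        raise ValueError(f"Failed to get embeddings: {e}") from e
    return [d.embedding for d in response.data]


if __name__ == "__main__":
    print(len(asyncio.run(embed(["hello"]))[0]))

Chaining with "from e" keeps the original HTTP 400 payload attached to the ValueError, so callers see a llama-stack-level error without losing the provider's diagnostic detail; on the client side, the server maps that ValueError back to a 400, which is why the updated test expects llama_stack_client's BadRequestError.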