fix: fix the error type in embedding test case (#3197)

# What does this PR do?
The embedding integration tests currently fail because the error type they expect does not match the one actually raised: the HTTP `LlamaStackClient` raises `llama_stack_client.BadRequestError`, while `LlamaStackAsLibraryClient` propagates whatever error the backend raises. This PR fixes the tests by selecting the expected error type based on the client in use, and removes the NVIDIA adapter's conversion of `BadRequestError` into `ValueError` so the backend error propagates as-is.
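As a minimal illustration of the failure mode (a hypothetical snippet, not from this PR; fixture names mirror the tests below): when the suite runs against `LlamaStackAsLibraryClient`, the backend's `openai.BadRequestError` propagates in-process, so a `pytest.raises` block expecting `llama_stack_client.BadRequestError` does not catch it and the test fails.

```python
# Hypothetical repro of the mismatch this PR fixes.
import pytest
from llama_stack_client import BadRequestError  # what the tests previously expected

def test_embedding_truncation_error_old(llama_stack_client, embedding_model_id):
    with pytest.raises(BadRequestError):
        # Against LlamaStackAsLibraryClient this call raises openai.BadRequestError,
        # which pytest.raises does not match, so the test fails.
        llama_stack_client.inference.embeddings(
            model_id=embedding_model_id,
            contents=["word " * 100_000],  # illustrative over-long input
        )
```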

## Test Plan

```
pytest -s -v tests/integration/inference/test_embedding.py --stack-config="inference=nvidia" --embedding-model="nvidia/llama-3.2-nv-embedqa-1b-v2" --env NVIDIA_API_KEY={nvidia_api_key} --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
Jiayi Ni · 2025-08-21 16:19:51 -07:00 · committed by GitHub
parent 864610ca5c · commit deffaa9e4e
2 changed files with 20 additions and 13 deletions

### NVIDIA inference adapter (`NVIDIAInferenceAdapter`)

```diff
@@ -7,7 +7,7 @@
 import warnings
 from collections.abc import AsyncIterator
 
-from openai import NOT_GIVEN, APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         }
         extra_body["input_type"] = task_type_options[task_type]
 
-        try:
-            response = await self.client.embeddings.create(
-                model=provider_model_id,
-                input=input,
-                extra_body=extra_body,
-            )
-        except BadRequestError as e:
-            raise ValueError(f"Failed to get embeddings: {e}") from e
-
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
+            input=input,
+            extra_body=extra_body,
+        )
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
         # ->
```
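In effect, the adapter change stops converting the OpenAI client's error into `ValueError`. A side-by-side sketch of the old and new behavior (function names are illustrative, not from the codebase):

```python
import openai

async def _embeddings_before(client, provider_model_id, input, extra_body):
    # Old: openai.BadRequestError was caught and re-raised as ValueError.
    try:
        return await client.embeddings.create(
            model=provider_model_id, input=input, extra_body=extra_body
        )
    except openai.BadRequestError as e:
        raise ValueError(f"Failed to get embeddings: {e}") from e

async def _embeddings_after(client, provider_model_id, input, extra_body):
    # New: openai.BadRequestError propagates unchanged to the caller.
    return await client.embeddings.create(
        model=provider_model_id, input=input, extra_body=extra_body
    )
```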

### tests/integration/inference/test_embedding.py

```diff
@@ -55,7 +55,7 @@
 #
 
 import pytest
-from llama_stack_client import BadRequestError
+from llama_stack_client import BadRequestError as LlamaStackBadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
@@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItemImageURL,
     TextContentItem,
 )
+from openai import BadRequestError as OpenAIBadRequestError
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
 
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
@@ -203,7 +206,14 @@ def test_embedding_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    # Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
+    # While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
+    error_type = (
+        OpenAIBadRequestError
+        if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
+        else LlamaStackBadRequestError
+    )
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_LONG_TEXT],
@@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_STRING],
```
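If more tests need this pattern, the client-dependent expectation could be factored into a small helper; a sketch (the helper is illustrative, not part of this PR):

```python
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from openai import BadRequestError as OpenAIBadRequestError  # for the usage example below
from llama_stack.core.library_client import LlamaStackAsLibraryClient

def expected_error(client, library_client_error: type[Exception]) -> type[Exception]:
    # The library client surfaces the backend's own exception
    # (openai.BadRequestError or ValueError in the tests above), while the
    # HTTP client raises llama_stack_client.BadRequestError for a 400 response.
    if isinstance(client, LlamaStackAsLibraryClient):
        return library_client_error
    return LlamaStackBadRequestError

# Usage mirroring the two tests above:
#   with pytest.raises(expected_error(llama_stack_client, OpenAIBadRequestError)): ...
#   with pytest.raises(expected_error(llama_stack_client, ValueError)): ...
```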