Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-22 17:53:55 +00:00
fix: fix the error type in embedding test case (#3197)
# What does this PR do?

Currently the embedding integration test cases fail because the error type the tests expect does not match the error actually raised. This PR fixes the embedding integration tests by asserting the correct error type for each client.

## Test Plan

```
pytest -s -v tests/integration/inference/test_embedding.py \
  --stack-config="inference=nvidia" \
  --embedding-model="nvidia/llama-3.2-nv-embedqa-1b-v2" \
  --env NVIDIA_API_KEY={nvidia_api_key} \
  --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
This commit is contained in:
parent 864610ca5c
commit deffaa9e4e

2 changed files with 20 additions and 13 deletions
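The gist of the fix: which `BadRequestError` a test should expect depends on whether the stack runs in-process or behind the HTTP client. A minimal sketch of that selection logic, assuming the imports the updated test file uses (the helper name `expected_embedding_error` is hypothetical, not part of the diff):

```python
# Minimal sketch of the error-type selection this PR introduces in the tests.
# The helper name is hypothetical; the imports mirror those added to
# tests/integration/inference/test_embedding.py.
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from openai import BadRequestError as OpenAIBadRequestError

from llama_stack.core.library_client import LlamaStackAsLibraryClient


def expected_embedding_error(client) -> type[Exception]:
    """Return the exception type an invalid embeddings call should raise."""
    if isinstance(client, LlamaStackAsLibraryClient):
        # In-process library client: the backend's own error propagates
        # unwrapped (for the NVIDIA adapter, the OpenAI SDK's BadRequestError).
        return OpenAIBadRequestError
    # HTTP client: server-side errors surface as llama_stack_client's
    # BadRequestError.
    return LlamaStackBadRequestError
```

A test can then write `with pytest.raises(expected_embedding_error(llama_stack_client)): ...`, which is equivalent to the conditional the diff below inlines into `test_embedding_truncation_error`.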
NVIDIA inference adapter (`NVIDIAInferenceAdapter`):

```diff
@@ -7,7 +7,7 @@
 import warnings
 from collections.abc import AsyncIterator
 
-from openai import NOT_GIVEN, APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
             }
             extra_body["input_type"] = task_type_options[task_type]
 
-        try:
-            response = await self.client.embeddings.create(
-                model=provider_model_id,
-                input=input,
-                extra_body=extra_body,
-            )
-        except BadRequestError as e:
-            raise ValueError(f"Failed to get embeddings: {e}") from e
-
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
+            input=input,
+            extra_body=extra_body,
+        )
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
         # ->
```
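With the try/except removed, a failing embeddings request in the NVIDIA adapter now propagates the OpenAI SDK's `BadRequestError` to the caller instead of being rewrapped as `ValueError`; the test changes below align their expected exception types with that behavior.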
tests/integration/inference/test_embedding.py:

```diff
@@ -55,7 +55,7 @@
 #
 
 import pytest
-from llama_stack_client import BadRequestError
+from llama_stack_client import BadRequestError as LlamaStackBadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
@@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItemImageURL,
     TextContentItem,
 )
+from openai import BadRequestError as OpenAIBadRequestError
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
 
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
@@ -203,7 +206,14 @@ def test_embedding_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    # Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
+    # While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
+    error_type = (
+        OpenAIBadRequestError
+        if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
+        else LlamaStackBadRequestError
+    )
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_LONG_TEXT],
@@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_STRING],
```
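Note that `test_embedding_text_truncation_error` still expects `ValueError` for the library client, presumably because the adapter's validation of the truncation parameter raises `ValueError` itself rather than surfacing an OpenAI error.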