handle error and update tests for new client

Author: Matthew Farrellee, 2025-02-21 16:38:39 -06:00
parent e4037d03c2
commit b54d896ce7
2 changed files with 19 additions and 20 deletions

NVIDIA inference adapter:

@@ -8,7 +8,7 @@ import logging
 import warnings
 from typing import AsyncIterator, List, Optional, Union
 
-from openai import APIConnectionError, AsyncOpenAI
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -170,11 +170,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
                 raise ValueError(f"Invalid task_type: {task_type}")
             extra_body["input_type"] = task_type_options[task_type]
 
-        response = await self._client.embeddings.create(
-            model=model,
-            input=input,
-            extra_body=extra_body,
-        )
+        try:
+            response = await self._client.embeddings.create(
+                model=model,
+                input=input,
+                extra_body=extra_body,
+            )
+        except BadRequestError as e:
+            raise ValueError(f"Failed to get embeddings: {e}") from e
 
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
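
For context: the adapter change above wraps the AsyncOpenAI embeddings call so that a client-side BadRequestError (for example, an input that exceeds the model's maximum length while truncation is disabled) is re-raised as a ValueError. A minimal standalone sketch of the same pattern follows; the function name, client handling, and return value are illustrative, not code from this commit.

# Sketch only: get_embeddings and its arguments are placeholders, not adapter code.
from openai import AsyncOpenAI, BadRequestError


async def get_embeddings(client: AsyncOpenAI, model: str, texts: list[str]) -> list[list[float]]:
    try:
        # The provider-side call; extra_body options (e.g. input_type) are omitted here.
        response = await client.embeddings.create(model=model, input=texts)
    except BadRequestError as e:
        # Invalid requests (such as over-long input) surface to the caller as ValueError.
        raise ValueError(f"Failed to get embeddings: {e}") from e
    return [d.embedding for d in response.data]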

Embedding client tests:

@@ -55,6 +55,7 @@
 #
 
 import pytest
+from llama_stack_client import BadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
@@ -63,8 +64,6 @@ from llama_stack_client.types.shared.interleaved_content import (
     TextContentItem,
 )
 
-from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation
-
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
 DUMMY_LONG_STRING = "NVDA " * 10240
@@ -120,8 +119,8 @@ def test_embedding_image(llama_stack_client, embedding_model_id, contents):
 @pytest.mark.parametrize(
     "text_truncation",
     [
-        TextTruncation.end,
-        TextTruncation.start,
+        "end",
+        "start",
     ],
 )
 @pytest.mark.parametrize(
@@ -149,7 +148,7 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
     "text_truncation",
     [
         None,
-        TextTruncation.none,
+        "none",
     ],
 )
 @pytest.mark.parametrize(
@@ -164,13 +163,10 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
     ],
 )
 def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents):
-    response = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
-    )
-    assert isinstance(response, EmbeddingsResponse)
-    assert len(response.embeddings) == 1
-    assert isinstance(response.embeddings[0], list)
-    assert isinstance(response.embeddings[0][0], float)
+    with pytest.raises(BadRequestError) as excinfo:
+        llama_stack_client.inference.embeddings(
+            model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
+        )
 
 
 @pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
@@ -186,9 +182,9 @@ def test_embedding_output_dimension(llama_stack_client, embedding_model_id):
 @pytest.mark.xfail(reason="Only valid for model supporting task type")
 def test_embedding_task_type(llama_stack_client, embedding_model_id):
     query_embedding = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.query
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
     )
     document_embedding = llama_stack_client.inference.embeddings(
-        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.document
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
     )
     assert query_embedding.embeddings != document_embedding.embeddings
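
For reference, the updated truncation test relies on the new client surfacing invalid requests as llama_stack_client.BadRequestError instead of returning an EmbeddingsResponse. A self-contained sketch of that pattern is below; the base_url and model id are assumptions for illustration, not values from this commit.

# Sketch of the pytest.raises pattern used above; base_url and model_id are assumed.
import pytest
from llama_stack_client import BadRequestError, LlamaStackClient


def test_embeddings_reject_overlong_input():
    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local Llama Stack server
    with pytest.raises(BadRequestError):
        client.inference.embeddings(
            model_id="nvidia/nv-embedqa-e5-v5",  # placeholder embedding model id
            contents=["NVDA " * 10240],  # oversized input, like DUMMY_LONG_STRING above
            text_truncation="none",  # truncation disabled, so the server should reject the request
        )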