handle error and update tests for new client

This commit is contained in:
Matthew Farrellee 2025-02-21 16:38:39 -06:00
parent e4037d03c2
commit b54d896ce7
2 changed files with 19 additions and 20 deletions

View file

@@ -8,7 +8,7 @@ import logging
import warnings
from typing import AsyncIterator, List, Optional, Union
from openai import APIConnectionError, AsyncOpenAI
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@@ -170,11 +170,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
raise ValueError(f"Invalid task_type: {task_type}")
extra_body["input_type"] = task_type_options[task_type]
response = await self._client.embeddings.create(
model=model,
input=input,
extra_body=extra_body,
)
try:
response = await self._client.embeddings.create(
model=model,
input=input,
extra_body=extra_body,
)
except BadRequestError as e:
raise ValueError(f"Failed to get embeddings: {e}") from e
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)

View file

@@ -55,6 +55,7 @@
#
import pytest
from llama_stack_client import BadRequestError
from llama_stack_client.types import EmbeddingsResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
@@ -63,8 +64,6 @@ from llama_stack_client.types.shared.interleaved_content import (
TextContentItem,
)
from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation
DUMMY_STRING = "hello"
DUMMY_STRING2 = "world"
DUMMY_LONG_STRING = "NVDA " * 10240
@@ -120,8 +119,8 @@ def test_embedding_image(llama_stack_client, embedding_model_id, contents):
@pytest.mark.parametrize(
"text_truncation",
[
TextTruncation.end,
TextTruncation.start,
"end",
"start",
],
)
@pytest.mark.parametrize(
@@ -149,7 +148,7 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
"text_truncation",
[
None,
TextTruncation.none,
"none",
],
)
@pytest.mark.parametrize(
@@ -164,13 +163,10 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
],
)
def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents):
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
with pytest.raises(BadRequestError) as excinfo:
llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
)
@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
@@ -186,9 +182,9 @@ def test_embedding_output_dimension(llama_stack_client, embedding_model_id):
@pytest.mark.xfail(reason="Only valid for model supporting task type")
def test_embedding_task_type(llama_stack_client, embedding_model_id):
query_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.query
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
)
document_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type=EmbeddingTaskType.document
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
)
assert query_embedding.embeddings != document_embedding.embeddings