mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 20:02:37 +00:00

update

This commit is contained in:
parent b58d8d8c90
commit d3782f0e3b

2 changed files with 18 additions and 13 deletions
```diff
@@ -7,7 +7,7 @@
 import warnings
 from collections.abc import AsyncIterator
 
-from openai import NOT_GIVEN, APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
```
```diff
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
             }
             extra_body["input_type"] = task_type_options[task_type]
 
-        try:
-            response = await self.client.embeddings.create(
-                model=provider_model_id,
-                input=input,
-                extra_body=extra_body,
-            )
-        except BadRequestError as e:
-            raise ValueError(f"Failed to get embeddings: {e}") from e
-
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
+            input=input,
+            extra_body=extra_body,
+        )
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
         # ->
```
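The effect of this hunk is that the adapter no longer converts `openai.BadRequestError` into `ValueError`; the OpenAI client's own exception now propagates to the caller. A minimal sketch of the call pattern the adapter uses, assuming an OpenAI-compatible NVIDIA endpoint (the base URL, model name, and `input_type` value below are illustrative assumptions, not taken from this commit):

```python
import asyncio

from openai import AsyncOpenAI, BadRequestError

async def main() -> None:
    # Hypothetical endpoint and model; NVIDIA NIM endpoints speak the OpenAI API.
    client = AsyncOpenAI(base_url="https://integrate.api.nvidia.com/v1", api_key="...")
    try:
        response = await client.embeddings.create(
            model="nvidia/nv-embedqa-e5-v5",
            input=["hello"],
            extra_body={"input_type": "query"},  # NVIDIA-specific field, as in the diff
        )
        print(len(response.data[0].embedding))
    except BadRequestError as err:
        # Before this commit the adapter re-raised this as ValueError;
        # now the openai.BadRequestError reaches the caller unchanged.
        print(f"embedding request rejected: {err}")

asyncio.run(main())
```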
```diff
@@ -55,6 +55,7 @@
 #
 
 import pytest
+from llama_stack_client import BadRequestError as LlamaStackBadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
```
```diff
@@ -62,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItemImageURL,
     TextContentItem,
 )
+from openai import BadRequestError as OpenAIBadRequestError
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
 
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
```
```diff
@@ -204,8 +208,12 @@ def test_embedding_truncation_error(
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
     # Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
     # While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
-    # Here we are using the LlamaStackAsLibraryClient, so the error raised is the same as what the backend raises
-    with pytest.raises(ValueError):
+    error_type = (
+        OpenAIBadRequestError
+        if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
+        else LlamaStackBadRequestError
+    )
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_LONG_TEXT],
```
```diff
@@ -285,7 +293,8 @@ def test_embedding_text_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(ValueError):
+    error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_STRING],
```
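The updated tests select the expected exception from the client type: the in-process `LlamaStackAsLibraryClient` lets the backend's own error surface (here `openai.BadRequestError`, or `ValueError` where the provider still raises it), while the HTTP `LlamaStackClient` maps the 400 response to `llama_stack_client.BadRequestError`. A small sketch of that dispatch outside pytest (the helper name is ours, not from the commit):

```python
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from openai import BadRequestError as OpenAIBadRequestError

from llama_stack.core.library_client import LlamaStackAsLibraryClient

def expected_embedding_error(client: object) -> type[Exception]:
    """Hypothetical helper mirroring the tests' error-type selection."""
    # The library client runs in-process and re-raises the provider's error;
    # the remote client wraps an HTTP 400 in its own BadRequestError.
    if isinstance(client, LlamaStackAsLibraryClient):
        return OpenAIBadRequestError
    return LlamaStackBadRequestError
```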