feat: add nvidia embedding implementation for new signature, task_type, output_dimension, text_truncation (#1213)

# What does this PR do?

Updates the NVIDIA inference provider's embedding implementation to use the new signature.

Adds support for the `task_type`, `output_dimension`, and `text_truncation` parameters.
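
For a quick sense of the new surface area, a minimal client-side sketch (the base URL and model name are borrowed from the test plan below; whether `output_dimension` and `task_type` are honored depends on the serving model):

```python
# Sketch: exercising the new embeddings parameters through the client SDK.
# Assumes a Llama Stack server on localhost:8321 with an NVIDIA-backed
# embedding model registered as "baai/bge-m3" (per the test plan below).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.embeddings(
    model_id="baai/bge-m3",
    contents=["hello world"],
    text_truncation="end",  # truncate overly long input at the end instead of erroring
    output_dimension=32,    # only honored by models supporting dimension reduction
    task_type="query",      # "query" or "document", for asymmetric embedding models
)
print(len(response.embeddings[0]))
```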

## Test Plan

```sh
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v \
  tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3
```
Matthew Farrellee 2025-02-27 18:58:11 -06:00 committed by GitHub
parent 73c6f6126f
commit e28cedd833
2 changed files with 161 additions and 14 deletions

`llama_stack/providers/remote/inference/nvidia/nvidia.py`

```diff
@@ -8,7 +8,7 @@ import logging
 import warnings
 from typing import AsyncIterator, List, Optional, Union
-from openai import APIConnectionError, AsyncOpenAI
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -144,19 +144,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #
         # we can ignore str and always pass List[str] to OpenAI
         #
-        flat_contents = [
-            item.text if isinstance(item, TextContentItem) else item
-            for content in contents
-            for item in (content if isinstance(content, list) else [content])
-        ]
+        flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
         model = self.get_provider_model_id(model_id)
+        extra_body = {}
+        if text_truncation is not None:
+            text_truncation_options = {
+                TextTruncation.none: "NONE",
+                TextTruncation.end: "END",
+                TextTruncation.start: "START",
+            }
+            extra_body["truncate"] = text_truncation_options[text_truncation]
+        if output_dimension is not None:
+            extra_body["dimensions"] = output_dimension
+        if task_type is not None:
+            task_type_options = {
+                EmbeddingTaskType.document: "passage",
+                EmbeddingTaskType.query: "query",
+            }
+            extra_body["input_type"] = task_type_options[task_type]
-        response = await self._client.embeddings.create(
-            model=model,
-            input=input,
-            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
-        )
+        try:
+            response = await self._client.embeddings.create(
+                model=model,
+                input=input,
+                extra_body=extra_body,
+            )
+        except BadRequestError as e:
+            raise ValueError(f"Failed to get embeddings: {e}") from e
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
```
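
The adapter change boils down to mapping Llama Stack's embedding options onto NVIDIA's OpenAI-compatible extra request fields. A standalone sketch of that mapping (the stand-in enums mirror the `TextTruncation` and `EmbeddingTaskType` values in the diff above; `build_extra_body` is a hypothetical helper, not part of the adapter):

```python
# Sketch: the parameter translation performed inside embeddings(), extracted
# as a pure function. The enums below stand in for the llama_stack ones.
from enum import Enum
from typing import Any, Dict, Optional


class TextTruncation(Enum):
    none = "none"
    start = "start"
    end = "end"


class EmbeddingTaskType(Enum):
    query = "query"
    document = "document"


def build_extra_body(
    text_truncation: Optional[TextTruncation] = None,
    output_dimension: Optional[int] = None,
    task_type: Optional[EmbeddingTaskType] = None,
) -> Dict[str, Any]:
    extra_body: Dict[str, Any] = {}
    if text_truncation is not None:
        # NVIDIA expects upper-case truncation directives
        extra_body["truncate"] = {
            TextTruncation.none: "NONE",
            TextTruncation.end: "END",
            TextTruncation.start: "START",
        }[text_truncation]
    if output_dimension is not None:
        extra_body["dimensions"] = output_dimension
    if task_type is not None:
        # asymmetric models embed queries and passages differently
        extra_body["input_type"] = {
            EmbeddingTaskType.document: "passage",
            EmbeddingTaskType.query: "query",
        }[task_type]
    return extra_body


assert build_extra_body(TextTruncation.end, 32, EmbeddingTaskType.query) == {
    "truncate": "END",
    "dimensions": 32,
    "input_type": "query",
}
```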

`tests/client-sdk/inference/test_embedding.py`

```diff
@@ -14,6 +14,23 @@
 #  - array of a text (TextContentItem)
 # Types of output:
 #  - list of list of floats
+# Params:
+#  - text_truncation
+#    - absent w/ long text -> error
+#    - none w/ long text -> error
+#    - absent w/ short text -> ok
+#    - none w/ short text -> ok
+#    - end w/ long text -> ok
+#    - end w/ short text -> ok
+#    - start w/ long text -> ok
+#    - start w/ short text -> ok
+#  - output_dimension
+#    - response dimension matches
+#  - task_type, only for asymmetric models
+#    - query embedding != passage embedding
+# Negative:
+#  - long string
+#  - long text
 #
 # Todo:
 #  - negative tests
@@ -23,8 +40,6 @@
 #  - empty text
 #  - empty image
 #  - long
-#    - long string
-#    - long text
 #    - large image
 #  - appropriate combinations
 #  - batch size
@@ -40,6 +55,7 @@
 #
 import pytest
+from llama_stack_client import BadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
@@ -50,8 +66,10 @@ from llama_stack_client.types.shared.interleaved_content import (
 DUMMY_STRING = "hello"
 DUMMY_STRING2 = "world"
+DUMMY_LONG_STRING = "NVDA " * 10240
 DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
 DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
+DUMMY_LONG_TEXT = TextContentItem(text=DUMMY_LONG_STRING, type="text")
 # TODO(mf): add a real image URL and base64 string
 DUMMY_IMAGE_URL = ImageContentItem(
     image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
```
```diff
@@ -89,10 +107,120 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents):
         "list[url,string,base64,text]",
     ],
 )
-@pytest.mark.skip(reason="Media is not supported")
+@pytest.mark.xfail(reason="Media is not supported")
 def test_embedding_image(llama_stack_client, embedding_model_id, contents):
     response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
     assert isinstance(response, EmbeddingsResponse)
     assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
     assert isinstance(response.embeddings[0], list)
     assert isinstance(response.embeddings[0][0], float)
+
+
+@pytest.mark.parametrize(
+    "text_truncation",
+    [
+        "end",
+        "start",
+    ],
+)
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_LONG_TEXT],
+        [DUMMY_STRING],
+    ],
+    ids=[
+        "long",
+        "short",
+    ],
+)
+def test_embedding_truncation(llama_stack_client, embedding_model_id, text_truncation, contents):
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
+    )
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == 1
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
+
+
+@pytest.mark.parametrize(
+    "text_truncation",
+    [
+        None,
+        "none",
+    ],
+)
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_LONG_TEXT],
+        [DUMMY_LONG_STRING],
+    ],
+    ids=[
+        "long-text",
+        "long-str",
+    ],
+)
+def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents):
+    with pytest.raises(BadRequestError):
+        llama_stack_client.inference.embeddings(
+            model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
+        )
+
+
+@pytest.mark.xfail(reason="Only valid for models supporting dimension reduction")
+def test_embedding_output_dimension(llama_stack_client, embedding_model_id):
+    base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
+    test_response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
+    )
+    assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
+    assert len(test_response.embeddings[0]) == 32
+
+
+@pytest.mark.xfail(reason="Only valid for models supporting task type")
+def test_embedding_task_type(llama_stack_client, embedding_model_id):
+    query_embedding = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
+    )
+    document_embedding = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
+    )
+    assert query_embedding.embeddings != document_embedding.embeddings
+
+
+@pytest.mark.parametrize(
+    "text_truncation",
+    [
+        None,
+        "none",
+        "end",
+        "start",
+    ],
+)
+def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation):
+    response = llama_stack_client.inference.embeddings(
+        model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+    )
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == 1
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
+
+
+@pytest.mark.parametrize(
+    "text_truncation",
+    [
+        "NONE",
+        "END",
+        "START",
+        "left",
+        "right",
+    ],
+)
+def test_embedding_text_truncation_error(llama_stack_client, embedding_model_id, text_truncation):
+    with pytest.raises(BadRequestError):
+        llama_stack_client.inference.embeddings(
+            model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+        )
```
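
To exercise only the new parameter tests, the test-plan command can be narrowed with a `-k` filter (a sketch; the filter terms match the test functions added above):

```sh
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v \
  tests/client-sdk/inference/test_embedding.py \
  -k "truncation or output_dimension or task_type" \
  --embedding-model baai/bge-m3
```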