add NVIDIA Inference embedding provider and tests

This commit is contained in:
Matthew Farrellee 2025-02-03 10:52:41 -05:00 committed by Ashwin Bharambe
parent 2608b6074f
commit 8706b311ba
5 changed files with 194 additions and 3 deletions

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
ProviderModelEntry(
provider_model_id="baai/bge-m3",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 1024,
"context_length": 8192,
},
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]

View file

@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union
from openai import APIConnectionError, AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -19,7 +24,6 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -117,9 +121,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
#
# Llama Stack: contents = List[str] | List[InterleavedContentItem]
# ->
# OpenAI: input = str | List[str]
#
# we can ignore str and always pass List[str] to OpenAI
#
flat_contents = [
item.text if isinstance(item, TextContentItem) else item
for content in contents
for item in (content if isinstance(content, list) else [content])
]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
model = self.get_provider_model_id(model_id)
response = await self._client.embeddings.create(
model=model,
input=input,
# extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent?
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
# ->
# Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def chat_completion(
self,

View file

@ -135,6 +135,13 @@ models:
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {
embedding_dimension: 1024
}
model_id: baai/bge-m3
provider_id: nvidia
provider_model_id: baai/bge-m3
model_type: embedding
shields: []
vector_dbs: []
datasets: []

View file

@ -58,6 +58,12 @@ def pytest_addoption(parser):
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield model to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=TEXT_MODEL,
help="Specify the embedding model to use for testing",
)
@pytest.fixture(scope="session")
@ -105,3 +111,9 @@ def pytest_generate_tests(metafunc):
[metafunc.config.getoption("--vision-inference-model")],
scope="session",
)
if "embedding_model_id" in metafunc.fixturenames:
metafunc.parametrize(
"embedding_model_id",
[metafunc.config.getoption("--embedding-model")],
scope="session",
)

View file

@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
#
# Test plan:
#
# Types of input:
# - array of a string
# - array of a image (ImageContentItem, either URL or base64 string)
# - array of a text (TextContentItem)
# - array of array of texts, images, or both
# Types of output:
# - list of list of floats
#
# Todo:
# - negative tests
# - empty
# - empty list
# - empty string
# - empty text
# - empty image
# - list of empty texts
# - list of empty images
# - list of empty texts and images
# - long
# - long string
# - long text
# - large image
# - appropriate combinations
# - batch size
# - many inputs
# - invalid
# - invalid URL
# - invalid base64
# - list of list of strings
#
# Notes:
# - use llama_stack_client fixture
# - use pytest.mark.parametrize when possible
# - no accuracy tests: only check the type of output, not the content
#
import pytest
from llama_stack_client.types import EmbeddingsResponse
from llama_stack_client.types.shared.interleaved_content import (
TextContentItem,
ImageContentItem,
ImageContentItemImage,
URL,
)
DUMMY_STRING = "hello"
DUMMY_STRING2 = "world"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
# TODO(mf): add a real image URL and base64 string
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
@pytest.mark.parametrize(
"contents",
[
[DUMMY_STRING],
[DUMMY_TEXT],
[DUMMY_STRING, DUMMY_STRING2],
[DUMMY_TEXT, DUMMY_TEXT2],
[[DUMMY_TEXT]],
[[DUMMY_TEXT], [DUMMY_TEXT, DUMMY_TEXT2]],
[DUMMY_STRING, [DUMMY_TEXT, DUMMY_TEXT2]],
],
ids=[
"string",
"text",
"list[string]",
"list[text]",
"list[list[text]]",
"list[list[text],list[text,text]]",
"string,list[text,text]",
],
)
def test_embedding_text(llama_stack_client, embedding_model_id, contents):
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
@pytest.mark.parametrize(
"contents",
[
[DUMMY_IMAGE_URL],
[DUMMY_IMAGE_BASE64],
[DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
[DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
[[DUMMY_IMAGE_URL]],
[[DUMMY_IMAGE_BASE64]],
[[DUMMY_IMAGE_URL, DUMMY_TEXT, DUMMY_IMAGE_BASE64]],
[[DUMMY_IMAGE_URL], [DUMMY_IMAGE_BASE64]],
[[DUMMY_IMAGE_URL], [DUMMY_TEXT, DUMMY_IMAGE_BASE64]],
],
ids=[
"url",
"base64",
"list[url,base64]",
"list[url,string,base64,text]",
"list[list[url]]",
"list[list[base64]]",
"list[list[url,text,base64]]",
"list[list[url],list[base64]]",
"list[list[url],list[text,base64]]",
],
)
@pytest.mark.skip(reason="Media is not supported")
def test_embedding_image(llama_stack_client, embedding_model_id, contents):
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)