mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00

add NVIDIA Inference embedding provider and tests

parent 2608b6074f
commit 8706b311ba

5 changed files with 194 additions and 3 deletions
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
 
@@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
         "meta/llama-3.2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="baai/bge-m3",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimensions": 1024,
+            "context_length": 8192,
+        },
+    ),
     # TODO(mf): how do we handle Nemotron models?
     # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
 ]
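The new ProviderModelEntry registers baai/bge-m3 as an embedding model, carrying its vector size and context window as metadata. As a rough sketch of what the registry gives the adapter, here is a simplified, hypothetical stand-in for the lookup (the real helper is ModelRegistryHelper, which the adapter inherits; Entry and REGISTRY below are illustrative names only):

# Hypothetical, simplified stand-in for the registry lookup; the real helper
# is llama_stack.providers.utils.inference.model_registry.ModelRegistryHelper.
from dataclasses import dataclass, field

@dataclass
class Entry:
    provider_model_id: str
    metadata: dict = field(default_factory=dict)

REGISTRY = {
    "baai/bge-m3": Entry("baai/bge-m3", {"embedding_dimensions": 1024, "context_length": 8192}),
}

def get_provider_model_id(model_id: str) -> str:
    # map a caller-facing model id to the id the NVIDIA endpoint expects
    return REGISTRY[model_id].provider_model_id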
@@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union
 
 from openai import APIConnectionError, AsyncOpenAI
 
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -19,7 +24,6 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -117,9 +121,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        if any(content_has_media(content) for content in contents):
+            raise NotImplementedError("Media is not supported")
+
+        #
+        # Llama Stack: contents = List[str] | List[InterleavedContentItem]
+        #  ->
+        # OpenAI: input = str | List[str]
+        #
+        # we can ignore str and always pass List[str] to OpenAI
+        #
+        flat_contents = [
+            item.text if isinstance(item, TextContentItem) else item
+            for content in contents
+            for item in (content if isinstance(content, list) else [content])
+        ]
+        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
+        model = self.get_provider_model_id(model_id)
+
+        response = await self._client.embeddings.create(
+            model=model,
+            input=input,
+            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
+        )
+
+        #
+        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
+        #  ->
+        # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
+        #
+        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
 
     async def chat_completion(
         self,
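To make the contents flattening above concrete, here is a self-contained sketch of the same transformation (a minimal TextContentItem stand-in is assumed so the snippet runs without the llama_stack packages):

from dataclasses import dataclass

@dataclass
class TextContentItem:  # minimal stand-in for llama_stack's TextContentItem
    text: str

def to_openai_input(contents):
    # flatten single items and nested lists into one flat list, then map
    # text items to their raw strings, matching what the adapter sends to OpenAI
    flat = [
        item
        for content in contents
        for item in (content if isinstance(content, list) else [content])
    ]
    return [c.text if isinstance(c, TextContentItem) else c for c in flat]

assert to_openai_input(["hello", "world"]) == ["hello", "world"]
assert to_openai_input([TextContentItem("hi"), [TextContentItem("a"), "b"]]) == ["hi", "a", "b"]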
@@ -135,6 +135,13 @@ models:
   provider_id: nvidia
   provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
+- metadata: {
+    embedding_dimension: 1024
+  }
+  model_id: baai/bge-m3
+  provider_id: nvidia
+  provider_model_id: baai/bge-m3
+  model_type: embedding
 shields: []
 vector_dbs: []
 datasets: []
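With the model registered in the run config, a client can request embeddings from it directly. A minimal usage sketch (the base URL and port are placeholders; assumes a server running this distribution):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder URL
response = client.inference.embeddings(
    model_id="baai/bge-m3",
    contents=["hello", "world"],
)
# expect one 1024-dimensional vector per input string
print(len(response.embeddings), len(response.embeddings[0]))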
@@ -58,6 +58,12 @@ def pytest_addoption(parser):
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=TEXT_MODEL,
+        help="Specify the embedding model to use for testing",
+    )
 
 
 @pytest.fixture(scope="session")
@@ -105,3 +111,9 @@ def pytest_generate_tests(metafunc):
         [metafunc.config.getoption("--vision-inference-model")],
         scope="session",
     )
+    if "embedding_model_id" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "embedding_model_id",
+            [metafunc.config.getoption("--embedding-model")],
+            scope="session",
+        )
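The new option feeds the embedding_model_id fixture via pytest_generate_tests, so the embedding tests can be pointed at a specific model from the command line, for example (an invocation sketch; when the flag is omitted, the TEXT_MODEL default is used):

pytest tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3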
tests/client-sdk/inference/test_embedding.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+#
+# Test plan:
+#
+# Types of input:
+#  - array of a string
+#  - array of an image (ImageContentItem, either URL or base64 string)
+#  - array of a text (TextContentItem)
+#  - array of arrays of texts, images, or both
+# Types of output:
+#  - list of lists of floats
+#
+# Todo:
+#  - negative tests
+#    - empty
+#      - empty list
+#      - empty string
+#      - empty text
+#      - empty image
+#      - list of empty texts
+#      - list of empty images
+#      - list of empty texts and images
+#    - long
+#      - long string
+#      - long text
+#      - large image
+#      - appropriate combinations
+#    - batch size
+#      - many inputs
+#    - invalid
+#      - invalid URL
+#      - invalid base64
+#      - list of lists of strings
+#
+# Notes:
+#  - use the llama_stack_client fixture
+#  - use pytest.mark.parametrize when possible
+#  - no accuracy tests: only check the type of output, not the content
+#
+
+import pytest
+
+from llama_stack_client.types import EmbeddingsResponse
+from llama_stack_client.types.shared.interleaved_content import (
+    TextContentItem,
+    ImageContentItem,
+    ImageContentItemImage,
+    URL,
+)
+
+
+DUMMY_STRING = "hello"
+DUMMY_STRING2 = "world"
+DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
+DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
+# TODO(mf): add a real image URL and base64 string
+DUMMY_IMAGE_URL = ImageContentItem(
+    image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image"
+)
+DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_STRING],
+        [DUMMY_TEXT],
+        [DUMMY_STRING, DUMMY_STRING2],
+        [DUMMY_TEXT, DUMMY_TEXT2],
+        [[DUMMY_TEXT]],
+        [[DUMMY_TEXT], [DUMMY_TEXT, DUMMY_TEXT2]],
+        [DUMMY_STRING, [DUMMY_TEXT, DUMMY_TEXT2]],
+    ],
+    ids=[
+        "string",
+        "text",
+        "list[string]",
+        "list[text]",
+        "list[list[text]]",
+        "list[list[text],list[text,text]]",
+        "string,list[text,text]",
+    ],
+)
+def test_embedding_text(llama_stack_client, embedding_model_id, contents):
+    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_IMAGE_URL],
+        [DUMMY_IMAGE_BASE64],
+        [DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
+        [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
+        [[DUMMY_IMAGE_URL]],
+        [[DUMMY_IMAGE_BASE64]],
+        [[DUMMY_IMAGE_URL, DUMMY_TEXT, DUMMY_IMAGE_BASE64]],
+        [[DUMMY_IMAGE_URL], [DUMMY_IMAGE_BASE64]],
+        [[DUMMY_IMAGE_URL], [DUMMY_TEXT, DUMMY_IMAGE_BASE64]],
+    ],
+    ids=[
+        "url",
+        "base64",
+        "list[url,base64]",
+        "list[url,string,base64,text]",
+        "list[list[url]]",
+        "list[list[base64]]",
+        "list[list[url,text,base64]]",
+        "list[list[url],list[base64]]",
+        "list[list[url],list[text,base64]]",
+    ],
+)
+@pytest.mark.skip(reason="Media is not supported")
+def test_embedding_image(llama_stack_client, embedding_model_id, contents):
+    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
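As a concrete instance of the count assertion: in the "string,list[text,text]" case, contents is [DUMMY_STRING, [DUMMY_TEXT, DUMMY_TEXT2]], which the adapter flattens to three strings, so the test expects 1 + 2 = 3 embeddings.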