From 832c535aafacab758a244a963e5a384f8b16c018 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 20 Feb 2025 18:59:48 -0600 Subject: [PATCH] feat(providers): add NVIDIA Inference embedding provider and tests (#935) # What does this PR do? add /v1/inference/embeddings implementation to NVIDIA provider **open topics** - - *asymmetric models*. NeMo Retriever includes asymmetric models, which are models that embed differently depending on if the input is destined for storage or lookup against storage. the /v1/inference/embeddings api does not allow the user to indicate the type of embedding to perform. see https://github.com/meta-llama/llama-stack/issues/934 - *truncation*. embedding models typically have a limited context window, e.g. 1024 tokens is common though newer models have 8k windows. when the input is larger than this window the endpoint cannot perform its designed function. two options: 0. return an error so the user can reduce the input size and retry; 1. perform truncation for the user and proceed (common strategies are left or right truncation). many users encounter context window size limits and will struggle to write reliable programs. this struggle is especially acute without access to the model's tokenizer. the /v1/inference/embeddings api does not allow the user to delegate truncation policy. see https://github.com/meta-llama/llama-stack/issues/933 - *dimensions*. "Matryoshka" embedding models are available. they allow users to control the number of embedding dimensions the model produces. this is a critical feature for managing storage constraints. embeddings of 1024 dimensions what achieve 95% recall for an application may not be worth the storage cost if a 512 dimensions can achieve 93% recall. controlling embedding dimensions allows applications to determine their recall and storage tradeoffs. the /v1/inference/embeddings api does not allow the user to control the output dimensions. see https://github.com/meta-llama/llama-stack/issues/932 ## Test Plan - `llama stack run llama_stack/templates/nvidia/run.yaml` - `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. --------- Co-authored-by: Ashwin Bharambe --- .../remote_hosted_distro/nvidia.md | 1 + .../remote/inference/nvidia/models.py | 10 ++ .../remote/inference/nvidia/nvidia.py | 39 ++++++- llama_stack/templates/nvidia/nvidia.py | 4 +- llama_stack/templates/nvidia/run.yaml | 7 ++ tests/client-sdk/conftest.py | 12 ++ tests/client-sdk/inference/test_embedding.py | 103 ++++++++++++++++++ 7 files changed, 172 insertions(+), 4 deletions(-) create mode 100644 tests/client-sdk/inference/test_embedding.py diff --git a/docs/source/distributions/remote_hosted_distro/nvidia.md b/docs/source/distributions/remote_hosted_distro/nvidia.md index f352f737e..a1f70e450 100644 --- a/docs/source/distributions/remote_hosted_distro/nvidia.md +++ b/docs/source/distributions/remote_hosted_distro/nvidia.md @@ -36,6 +36,7 @@ The following models are available by default: - `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)` - `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)` - `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)` +- `baai/bge-m3 (baai/bge-m3)` ### Prerequisite: API Keys diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py index c432861ee..fa9944be1 100644 --- a/llama_stack/providers/remote/inference/nvidia/models.py +++ b/llama_stack/providers/remote/inference/nvidia/models.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.models import ModelType from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, build_hf_repo_model_entry, ) @@ -46,6 +48,14 @@ _MODEL_ENTRIES = [ "meta/llama-3.2-90b-vision-instruct", CoreModelId.llama3_2_90b_vision_instruct.value, ), + ProviderModelEntry( + provider_model_id="baai/bge-m3", + model_type=ModelType.embedding, + metadata={ + "embedding_dimensions": 1024, + "context_length": 8192, + }, + ), # TODO(mf): how do we handle Nemotron models? # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", ] diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 824389577..6f38230b2 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union from openai import APIConnectionError, AsyncOpenAI +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, + TextContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -19,7 +24,6 @@ from llama_stack.apis.inference import ( CompletionResponseStreamChunk, EmbeddingsResponse, Inference, - InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -117,9 +121,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: - raise NotImplementedError() + if any(content_has_media(content) for content in contents): + raise NotImplementedError("Media is not supported") + + # + # Llama Stack: contents = List[str] | List[InterleavedContentItem] + # -> + # OpenAI: input = str | List[str] + # + # we can ignore str and always pass List[str] to OpenAI + # + flat_contents = [ + item.text if isinstance(item, TextContentItem) else item + for content in contents + for item in (content if isinstance(content, list) else [content]) + ] + input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents] + model = self.get_provider_model_id(model_id) + + response = await self._client.embeddings.create( + model=model, + input=input, + # extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent? + ) + + # + # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...) + # -> + # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]]) + # + return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) async def chat_completion( self, diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index a505a1b93..56d13a09a 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -41,9 +41,11 @@ def get_distribution_template() -> DistributionTemplate: core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} default_models = [ ModelInput( - model_id=core_model_to_hf_repo[m.llama_model], + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, provider_model_id=m.provider_model_id, provider_id="nvidia", + model_type=m.model_type, + metadata=m.metadata, ) for m in _MODEL_ENTRIES ] diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 14fb28354..891fd112a 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -135,6 +135,13 @@ models: provider_id: nvidia provider_model_id: meta/llama-3.2-90b-vision-instruct model_type: llm +- metadata: + embedding_dimensions: 1024 + context_length: 8192 + model_id: baai/bge-m3 + provider_id: nvidia + provider_model_id: baai/bge-m3 + model_type: embedding shields: [] vector_dbs: [] datasets: [] diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index b397f7ab3..efdec6b01 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -58,6 +58,12 @@ def pytest_addoption(parser): default="meta-llama/Llama-Guard-3-1B", help="Specify the safety shield model to use for testing", ) + parser.addoption( + "--embedding-model", + action="store", + default=TEXT_MODEL, + help="Specify the embedding model to use for testing", + ) @pytest.fixture(scope="session") @@ -105,3 +111,9 @@ def pytest_generate_tests(metafunc): [metafunc.config.getoption("--vision-inference-model")], scope="session", ) + if "embedding_model_id" in metafunc.fixturenames: + metafunc.parametrize( + "embedding_model_id", + [metafunc.config.getoption("--embedding-model")], + scope="session", + ) diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py new file mode 100644 index 000000000..a25382866 --- /dev/null +++ b/tests/client-sdk/inference/test_embedding.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +# +# Test plan: +# +# Types of input: +# - array of a string +# - array of a image (ImageContentItem, either URL or base64 string) +# - array of a text (TextContentItem) +# - array of array of texts, images, or both +# Types of output: +# - list of list of floats +# +# Todo: +# - negative tests +# - empty +# - empty list +# - empty string +# - empty text +# - empty image +# - list of empty texts +# - list of empty images +# - list of empty texts and images +# - long +# - long string +# - long text +# - large image +# - appropriate combinations +# - batch size +# - many inputs +# - invalid +# - invalid URL +# - invalid base64 +# - list of list of strings +# +# Notes: +# - use llama_stack_client fixture +# - use pytest.mark.parametrize when possible +# - no accuracy tests: only check the type of output, not the content +# + +import pytest +from llama_stack_client.types import EmbeddingsResponse +from llama_stack_client.types.shared.interleaved_content import ( + URL, + ImageContentItem, + ImageContentItemImage, + TextContentItem, +) + +DUMMY_STRING = "hello" +DUMMY_STRING2 = "world" +DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") +DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") +# TODO(mf): add a real image URL and base64 string +DUMMY_IMAGE_URL = ImageContentItem( + image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image" +) +DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") + + +@pytest.mark.parametrize( + "contents", + [ + [DUMMY_STRING, DUMMY_STRING2], + [DUMMY_TEXT, DUMMY_TEXT2], + ], + ids=[ + "list[string]", + "list[text]", + ], +) +def test_embedding_text(llama_stack_client, embedding_model_id, contents): + response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) + assert isinstance(response.embeddings[0], list) + assert isinstance(response.embeddings[0][0], float) + + +@pytest.mark.parametrize( + "contents", + [ + [DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64], + [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT], + ], + ids=[ + "list[url,base64]", + "list[url,string,base64,text]", + ], +) +@pytest.mark.skip(reason="Media is not supported") +def test_embedding_image(llama_stack_client, embedding_model_id, contents): + response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) + assert isinstance(response.embeddings[0], list) + assert isinstance(response.embeddings[0][0], float)