diff --git a/docs/source/distributions/remote_hosted_distro/nvidia.md b/docs/source/distributions/remote_hosted_distro/nvidia.md
index f352f737e..a1f70e450 100644
--- a/docs/source/distributions/remote_hosted_distro/nvidia.md
+++ b/docs/source/distributions/remote_hosted_distro/nvidia.md
@@ -36,6 +36,7 @@ The following models are available by default:
 - `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
 - `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
 - `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
+- `baai/bge-m3 (baai/bge-m3)`
 
 ### Prerequisite: API Keys
 
diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py
index c432861ee..fa9944be1 100644
--- a/llama_stack/providers/remote/inference/nvidia/models.py
+++ b/llama_stack/providers/remote/inference/nvidia/models.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
 
@@ -46,6 +48,14 @@ _MODEL_ENTRIES = [
         "meta/llama-3.2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="baai/bge-m3",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimensions": 1024,
+            "context_length": 8192,
+        },
+    ),
     # TODO(mf): how do we handle Nemotron models?
     # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
 ]
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 824389577..6f38230b2 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union
 
 from openai import APIConnectionError, AsyncOpenAI
 
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -19,7 +24,6 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -117,9 +121,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        if any(content_has_media(content) for content in contents):
+            raise NotImplementedError("Media is not supported")
+
+        #
+        # Llama Stack: contents = List[str] | List[InterleavedContentItem]
+        #  ->
+        # OpenAI: input = str | List[str]
+        #
+        # we can ignore str and always pass List[str] to OpenAI
+        #
+        flat_contents = [
+            item.text if isinstance(item, TextContentItem) else item
+            for content in contents
+            for item in (content if isinstance(content, list) else [content])
+        ]
+        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
+        model = self.get_provider_model_id(model_id)
+
+        response = await self._client.embeddings.create(
+            model=model,
+            input=input,
+            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
+        )
+
+        #
+        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
+        #  ->
+        # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
+        #
+        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
 
     async def chat_completion(
         self,
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index a505a1b93..56d13a09a 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -41,9 +41,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="nvidia",
+            model_type=m.model_type,
+            metadata=m.metadata,
         )
         for m in _MODEL_ENTRIES
     ]
diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml
index 14fb28354..891fd112a 100644
--- a/llama_stack/templates/nvidia/run.yaml
+++ b/llama_stack/templates/nvidia/run.yaml
@@ -135,6 +135,13 @@ models:
   provider_id: nvidia
   provider_model_id: meta/llama-3.2-90b-vision-instruct
   model_type: llm
+- metadata:
+    embedding_dimensions: 1024
+    context_length: 8192
+  model_id: baai/bge-m3
+  provider_id: nvidia
+  provider_model_id: baai/bge-m3
+  model_type: embedding
 shields: []
 vector_dbs: []
 datasets: []
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index b397f7ab3..efdec6b01 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -58,6 +58,12 @@ def pytest_addoption(parser):
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=TEXT_MODEL,
+        help="Specify the embedding model to use for testing",
+    )
 
 
 @pytest.fixture(scope="session")
@@ -105,3 +111,9 @@ def pytest_generate_tests(metafunc):
             [metafunc.config.getoption("--vision-inference-model")],
             scope="session",
         )
+    if "embedding_model_id" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "embedding_model_id",
+            [metafunc.config.getoption("--embedding-model")],
+            scope="session",
+        )
diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py
new file mode 100644
index 000000000..a25382866
--- /dev/null
+++ b/tests/client-sdk/inference/test_embedding.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+#
+# Test plan:
+#
+# Types of input:
+#  - array of a string
+#  - array of an image (ImageContentItem, either URL or base64 string)
+#  - array of a text (TextContentItem)
+#  - array of array of texts, images, or both
+# Types of output:
+#  - list of list of floats
+#
+# Todo:
+#  - negative tests
+#    - empty
+#      - empty list
+#      - empty string
+#      - empty text
+#      - empty image
+#      - list of empty texts
+#      - list of empty images
+#      - list of empty texts and images
+#    - long
+#      - long string
+#      - long text
+#      - large image
+#      - appropriate combinations
+#    - batch size
+#      - many inputs
+#    - invalid
+#      - invalid URL
+#      - invalid base64
+#      - list of list of strings
+#
+# Notes:
+#  - use llama_stack_client fixture
+#  - use pytest.mark.parametrize when possible
+#  - no accuracy tests: only check the type of output, not the content
+#
+
+import pytest
+from llama_stack_client.types import EmbeddingsResponse
+from llama_stack_client.types.shared.interleaved_content import (
+    URL,
+    ImageContentItem,
+    ImageContentItemImage,
+    TextContentItem,
+)
+
+DUMMY_STRING = "hello"
+DUMMY_STRING2 = "world"
+DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
+DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
+# TODO(mf): add a real image URL and base64 string
+DUMMY_IMAGE_URL = ImageContentItem(
+    image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image"
+)
+DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_STRING, DUMMY_STRING2],
+        [DUMMY_TEXT, DUMMY_TEXT2],
+    ],
+    ids=[
+        "list[string]",
+        "list[text]",
+    ],
+)
+def test_embedding_text(llama_stack_client, embedding_model_id, contents):
+    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
+        [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
+    ],
+    ids=[
+        "list[url,base64]",
+        "list[url,string,base64,text]",
+    ],
+)
+@pytest.mark.skip(reason="Media is not supported")
+def test_embedding_image(llama_stack_client, embedding_model_id, contents):
+    response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+    assert isinstance(response, EmbeddingsResponse)
+    assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
+    assert isinstance(response.embeddings[0], list)
+    assert isinstance(response.embeddings[0][0], float)
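
For reviewers who want to exercise the new path by hand, here is a minimal usage sketch (not part of the patch). It assumes a stack built from the nvidia template is already serving on the default port 8321 with NVIDIA_API_KEY exported; the model id and the embeddings call mirror the entries and tests added above, everything else (base_url, inputs) is illustrative.

    # usage sketch: call the new embeddings endpoint through the client SDK
    from llama_stack_client import LlamaStackClient

    # assumed local server started from the nvidia distro template
    client = LlamaStackClient(base_url="http://localhost:8321")

    # plain strings are accepted; the adapter flattens TextContentItem inputs the same way
    response = client.inference.embeddings(
        model_id="baai/bge-m3",
        contents=["what is the capital of France?", "Paris is the capital of France"],
    )

    # one vector per input, each a list of floats (1024 per the metadata registered above)
    assert len(response.embeddings) == 2
    print(len(response.embeddings[0]))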