diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 8e89bcc72..4d21010b9 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -202,10 +202,15 @@ API responses, specify the adapter here. return self.adapter.provider_data_validator -def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: +def remote_provider_spec( + api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None +) -> RemoteProviderSpec: + if api_dependencies is None: + api_dependencies = [] return RemoteProviderSpec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, adapter=adapter, + api_dependencies=api_dependencies, ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 07fd4af44..43132da72 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.utils.inference.model_registry import build_model_alias from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_media_to_url, request_has_media, ) - from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -32,12 +34,17 @@ log = logging.getLogger(__name__) SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): +class MetaReferenceInferenceImpl( + SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) - ModelRegistryHelper.__init__( - self, + if model is None: + raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") + self.model_registry_helper = ModelRegistryHelper( [ build_model_alias( model.descriptor(), @@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP ) ], ) - if model is None: - raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model # verify that the checkpoint actually is for this model lol @@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP async def unregister_model(self, model_id: str) -> None: pass + async def register_model(self, model: Model) -> Model: + model = await self.model_registry_helper.register_model(model) + if model.model_type == ModelType.embedding_model: + self._get_embedding_model(model.provider_resource_id) + return model + async def completion( self, model_id: str, @@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP for x in impl(): yield x - async def embeddings( - self, - model_id: str, - contents: List[InterleavedTextMedia], - ) -> EmbeddingsResponse: - raise NotImplementedError() - async def request_with_localized_media( request: Union[ChatCompletionRequest, 
CompletionRequest], diff --git a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py new file mode 100644 index 000000000..d5710f7fd --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.inline.inference.sentence_transformers.config import ( + SentenceTransformersInferenceConfig, +) + + +async def get_provider_impl( + config: SentenceTransformersInferenceConfig, + _deps, +): + from .sentence_transformers import SentenceTransformersInferenceImpl + + impl = SentenceTransformersInferenceImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/inference/sentence_transformers/config.py b/llama_stack/providers/inline/inference/sentence_transformers/config.py new file mode 100644 index 000000000..aec6d56d8 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class SentenceTransformersInferenceConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py new file mode 100644 index 000000000..18f4d8cd6 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
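+
+# NOTE: this inline provider only serves embeddings (via the shared
+# SentenceTransformerEmbeddingMixin); completion() and chat_completion()
+# below intentionally raise NotImplementedError.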
+ +import logging +from typing import AsyncGenerator, List, Optional, Union + +from llama_stack.apis.inference import ( + CompletionResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) +from .config import SentenceTransformersInferenceConfig + +log = logging.getLogger(__name__) + + +class SentenceTransformersInferenceImpl( + SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def check_model(self, request) -> None: + if request.model != self.config.model: + raise RuntimeError( + f"Model mismatch: {request.model} != {self.config.model}" + ) + + async def register_model(self, model: Model) -> None: + _ = self._get_embedding_model(model.provider_resource_id) + return model + + async def unregister_model(self, model_id: str) -> None: + pass + + async def completion( + self, + model_id: str, + content: str, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, AsyncGenerator]: + raise NotImplementedError("Sentence transformers don't support completion") + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + raise NotImplementedError("Sentence transformers don't support chat completion") diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 13d463ad8..0ff557b9f 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [ "transformers", "zmq", "lm-format-enforcer", + "sentence-transformers", ] @@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.inference.vllm", config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", ), + InlineProviderSpec( + api=Api.inference, + provider_type="inline::sentence-transformers", + pip_packages=["sentence-transformers"], + module="llama_stack.providers.inline.inference.sentence_transformers", + config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f575d9dc3..b11cef22b 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
from typing import * # noqa: F403 +import json from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -448,4 +449,18 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + embeddings = [] + for content in contents: + input_text = str(content) if not isinstance(content, str) else content + input_body = {"inputText": input_text} + body = json.dumps(input_body) + response = self.client.invoke_model( + body=body, + modelId=model.provider_resource_id, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + embeddings.append(response_body.get("embedding")) + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 062c1e1ea..e69926942 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field @json_schema_type class FireworksImplConfig(BaseModel): url: str = Field( - default="https://api.fireworks.ai/inference", + default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) api_key: Optional[str] = Field( @@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel): @classmethod def sample_run_config(cls) -> Dict[str, Any]: return { - "url": "https://api.fireworks.ai/inference", + "url": "https://api.fireworks.ai/inference/v1", "api_key": "${env.FIREWORKS_API_KEY}", } diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index c3e634155..e4836ba7c 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
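+
+# NOTE: embeddings go through Fireworks' OpenAI-compatible endpoint, hence the
+# OpenAI client import below and the "/v1" suffix added to the default URL in
+# config.py.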
-from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks from llama_models.datatypes import CoreModelId @@ -12,6 +12,7 @@ from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer +from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( @@ -89,19 +90,24 @@ class FireworksInferenceAdapter( async def shutdown(self) -> None: pass - def _get_client(self) -> Fireworks: - fireworks_api_key = None + def _get_api_key(self) -> str: if self.config.api_key is not None: - fireworks_api_key = self.config.api_key + return self.config.api_key else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.fireworks_api_key: raise ValueError( 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": }' ) - fireworks_api_key = provider_data.fireworks_api_key + return provider_data.fireworks_api_key + + def _get_client(self) -> Fireworks: + fireworks_api_key = self._get_api_key() return Fireworks(api_key=fireworks_api_key) + def _get_openai_client(self) -> OpenAI: + return OpenAI(base_url=self.config.url, api_key=self._get_api_key()) + async def completion( self, model_id: str, @@ -264,4 +270,15 @@ class FireworksInferenceAdapter( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + client = self._get_openai_client() + kwargs = {} + if model.metadata.get("embedding_dimensions"): + kwargs["dimensions"] = model.metadata.get("embedding_dimensions") + response = client.embeddings.create( + model=model.provider_resource_id, input=contents, **kwargs + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index d6fa20835..3e335e854 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -321,9 +321,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + response = await self.client.embed( + model=model.provider_resource_id, input=contents + ) + embeddings = response["embeddings"] + + return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: + # ollama does not have embedding models running. Check if the model is in list of available models. + if model.model_type == ModelType.embedding_model: + response = await self.client.list() + available_models = [m["model"] for m in response["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. 
" + f"Available models: {', '.join(available_models)}" + ) + return model model = await self.register_helper.register_model(model) models = await self.client.ps() available_models = [m["model"] for m in models["models"]] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e7c96ce98..52079ef85 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -253,4 +253,9 @@ class TogetherInferenceAdapter( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + r = self._get_client().embeddings.create( + model=model.provider_resource_id, input=contents + ) + embeddings = [item.embedding for item in r.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 57f3db802..ca3d24c6a 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -203,4 +203,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + kwargs = {} + if model.metadata.get("embedding_dimensions"): + kwargs["dimensions"] = model.metadata.get("embedding_dimensions") + response = self.client.embeddings.create( + model=model.provider_resource_id, input=contents, **kwargs + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 7fe19b403..a195e4f3a 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -18,6 +18,12 @@ def pytest_addoption(parser): default=None, help="Specify the inference model to use for testing", ) + parser.addoption( + "--embedding-model", + action="store", + default=None, + help="Specify the embedding model to use for testing", + ) def pytest_configure(config): @@ -78,3 +84,24 @@ def pytest_generate_tests(metafunc): ): fixtures = [stack.values[0]["inference"] for stack in filtered_stacks] metafunc.parametrize("inference_stack", fixtures, indirect=True) + + if "embedding_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--embedding-model") + if not model: + raise ValueError( + "No embedding model specified. Please provide a valid embedding model." 
+ ) + params = [pytest.param(model, id="")] + + metafunc.parametrize("embedding_model", params, indirect=True) + + if "embedding_stack" in metafunc.fixturenames: + fixtures = INFERENCE_FIXTURES + if filtered_stacks := get_provider_fixture_overrides( + metafunc.config, + { + "inference": INFERENCE_FIXTURES, + }, + ): + fixtures = [stack.values[0]["inference"] for stack in filtered_stacks] + metafunc.parametrize("embedding_stack", fixtures, indirect=True) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 21e122149..20aa94cf0 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,9 +9,9 @@ import os import pytest import pytest_asyncio -from llama_stack.apis.models import ModelInput - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider + from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) @@ -37,6 +37,13 @@ def inference_model(request): return request.config.getoption("--inference-model", None) +@pytest.fixture(scope="session") +def embedding_model(request): + if hasattr(request, "param"): + return request.param + return request.config.getoption("--embedding-model", None) + + @pytest.fixture(scope="session") def inference_remote() -> ProviderFixture: return remote_stack_fixture() @@ -85,7 +92,7 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) - if "Llama3.1-8B-Instruct" in inference_model: + if inference_model and "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") return ProviderFixture( @@ -240,3 +247,25 @@ async def inference_stack(request, inference_model): ) return test_stack.impls[Api.inference], test_stack.impls[Api.models] + + +@pytest_asyncio.fixture(scope="session") +async def embedding_stack(request, embedding_model): + fixture_name = request.param + inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") + test_stack = await construct_stack_for_test( + [Api.inference], + {"inference": inference_fixture.providers}, + inference_fixture.provider_data, + models=[ + ModelInput( + model_id=embedding_model, + model_type=ModelType.embedding_model, + metadata={ + "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), + }, + ) + ], + ) + + return test_stack.impls[Api.inference], test_stack.impls[Api.models] diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py new file mode 100644 index 000000000..03aee0c28 --- /dev/null +++ b/llama_stack/providers/tests/inference/test_embeddings.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
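+
+# These tests need an embedding model selected via --embedding-model and the
+# EMBEDDING_DIMENSION environment variable consumed by the embedding_stack
+# fixture (see conftest.py and fixtures.py).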
+ +import pytest + +from llama_stack.apis.inference import EmbeddingsResponse, ModelType + +# How to run this test: +# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py + + +class TestEmbeddings: + @pytest.mark.asyncio + async def test_embeddings(self, embedding_model, embedding_stack): + inference_impl, models_impl = embedding_stack + model = await models_impl.get_model(embedding_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + response = await inference_impl.embeddings( + model_id=embedding_model, + contents=["Hello, world!"], + ) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) > 0 + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + @pytest.mark.asyncio + async def test_batch_embeddings(self, embedding_model, embedding_stack): + inference_impl, models_impl = embedding_stack + model = await models_impl.get_model(embedding_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + texts = ["Hello, world!", "This is a test", "Testing embeddings"] + + response = await inference_impl.embeddings( + model_id=embedding_model, + contents=texts, + ) + + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == len(texts) + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + embedding_dim = len(response.embeddings[0]) + assert all(len(embedding) == embedding_dim for embedding in response.embeddings) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py new file mode 100644 index 000000000..141b02398 --- /dev/null +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import List + +from llama_models.llama3.api.datatypes import InterleavedTextMedia + +from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore + +EMBEDDING_MODELS = {} + + +log = logging.getLogger(__name__) + + +class SentenceTransformerEmbeddingMixin: + model_store: ModelStore + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + model = await self.model_store.get_model(model_id) + embedding_model = self._get_embedding_model(model.provider_resource_id) + embeddings = embedding_model.encode(contents) + return EmbeddingsResponse(embeddings=embeddings) + + def _get_embedding_model(self, model: str) -> "SentenceTransformer": + global EMBEDDING_MODELS + + loaded_model = EMBEDDING_MODELS.get(model) + if loaded_model is not None: + return loaded_model + + log.info(f"Loading sentence transformer for {model}...") + from sentence_transformers import SentenceTransformer + + loaded_model = SentenceTransformer(model) + EMBEDDING_MODELS[model] = loaded_model + return loaded_model
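
A few standalone sketches for reviewers, outside the stack plumbing. First, the caching pattern that `SentenceTransformerEmbeddingMixin._get_embedding_model` implements: each sentence-transformers model is loaded once per process into the module-level `EMBEDDING_MODELS` dict and reused across `embeddings()` calls. The model name below is only an illustrative choice, not something this patch pins.

```python
# Standalone sketch of the mixin's caching behavior: load each
# sentence-transformers model once per process, then reuse it for encode().
from sentence_transformers import SentenceTransformer

_EMBEDDING_MODELS: dict[str, SentenceTransformer] = {}


def get_embedding_model(name: str) -> SentenceTransformer:
    model = _EMBEDDING_MODELS.get(name)
    if model is None:
        model = SentenceTransformer(name)  # downloads/loads weights on first use
        _EMBEDDING_MODELS[name] = model
    return model


# encode() takes a list of strings and returns one vector per input.
vectors = get_embedding_model("all-MiniLM-L6-v2").encode(["Hello, world!"])
print(len(vectors), len(vectors[0]))  # 1 384 for this particular model
```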
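On Bedrock, the adapter builds a Titan-style `{"inputText": ...}` body and reads a single `embedding` list back, which is why `embeddings()` loops over `contents` and issues one `invoke_model` call per item. A sketch of that request/response shape, assuming a boto3 `bedrock-runtime` client and an illustrative Titan model id:

```python
# One invoke_model round-trip per content item, mirroring
# BedrockInferenceAdapter.embeddings(). The model id below is illustrative.
import json

import boto3

client = boto3.client("bedrock-runtime")
response = client.invoke_model(
    body=json.dumps({"inputText": "Hello, world!"}),
    modelId="amazon.titan-embed-text-v1",
    accept="application/json",
    contentType="application/json",
)
embedding = json.loads(response["body"].read())["embedding"]  # list[float]
```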
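Fireworks and vLLM issue the embeddings call through the OpenAI SDK against an OpenAI-compatible `/v1/embeddings` endpoint (hence the `/v1` suffix now in the Fireworks default URL), and Together's client exposes the same `embeddings.create()` shape. A minimal sketch of that call; the base URL and model id are illustrative, and `dimensions` is only forwarded when the registered model's metadata supplies it:

```python
# The same call shape the Fireworks/vLLM adapters produce via the OpenAI SDK.
from openai import OpenAI

client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",  # illustrative endpoint
    api_key="YOUR_API_KEY",
)
response = client.embeddings.create(
    model="nomic-ai/nomic-embed-text-v1.5",  # illustrative embedding model id
    input=["Hello, world!", "This is a test"],
    # dimensions=768,  # passed only when model metadata sets "embedding_dimensions"
)
vectors = [item.embedding for item in response.data]
```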
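The Ollama adapter uses the client's `embed()` API, which batches all `contents` in one call and returns the vectors under `"embeddings"`; `register_model` additionally refuses embedding models that are not already present in the local Ollama install. A rough standalone equivalent, assuming the async Ollama client and an illustrative model that has already been pulled:

```python
# Rough equivalent of OllamaInferenceAdapter.embeddings(), calling the async
# Ollama client directly. "all-minilm" is an illustrative, already-pulled model.
import asyncio

from ollama import AsyncClient


async def embed_texts(texts: list[str]) -> list[list[float]]:
    response = await AsyncClient().embed(model="all-minilm", input=texts)
    return response["embeddings"]


print(asyncio.run(embed_texts(["Hello, world!", "This is a test"])))
```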