From 96e158eaac4aca62a62afeae40558e053627e547 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 12 Dec 2024 11:47:50 -0800 Subject: [PATCH] Make embedding generation go through inference (#606) This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers. 2) Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models This is a merge from #589 and #598 --- llama_stack/apis/memory_banks/memory_banks.py | 1 + llama_stack/apis/models/models.py | 11 ++- llama_stack/distribution/routers/routers.py | 24 +++++- .../distribution/routers/routing_tables.py | 44 +++++++--- llama_stack/distribution/store/registry.py | 2 +- llama_stack/providers/datatypes.py | 5 +- .../inference/meta_reference/inference.py | 30 ++++--- .../sentence_transformers/__init__.py | 20 +++++ .../inference/sentence_transformers/config.py | 10 +++ .../sentence_transformers.py | 74 +++++++++++++++++ .../providers/inline/memory/faiss/__init__.py | 7 +- .../providers/inline/memory/faiss/faiss.py | 41 ++++++---- llama_stack/providers/registry/inference.py | 8 ++ llama_stack/providers/registry/memory.py | 7 ++ .../remote/inference/bedrock/bedrock.py | 22 ++++- .../remote/inference/fireworks/config.py | 4 +- .../remote/inference/fireworks/fireworks.py | 30 +++++-- .../remote/inference/ollama/ollama.py | 24 +++++- .../remote/inference/together/together.py | 12 ++- .../providers/remote/inference/vllm/vllm.py | 19 ++++- .../remote/memory/chroma/__init__.py | 10 ++- .../providers/remote/memory/chroma/chroma.py | 18 +++-- .../remote/memory/pgvector/__init__.py | 8 +- .../remote/memory/pgvector/pgvector.py | 38 ++++----- .../remote/memory/qdrant/__init__.py | 8 +- .../providers/remote/memory/qdrant/qdrant.py | 5 +- .../remote/memory/weaviate/__init__.py | 8 +- .../remote/memory/weaviate/weaviate.py | 27 +++++-- .../providers/tests/inference/conftest.py | 6 ++ .../providers/tests/inference/fixtures.py | 23 +++++- .../tests/inference/test_embeddings.py | 62 ++++++++++++++ .../providers/tests/memory/conftest.py | 80 +++++++++++++++++-- .../providers/tests/memory/fixtures.py | 30 +++++-- .../providers/tests/memory/test_memory.py | 26 +++--- .../utils/inference/embedding_mixin.py | 47 +++++++++++ .../utils/inference/model_registry.py | 9 ++- .../providers/utils/memory/vector_store.py | 33 +++----- 37 files changed, 677 insertions(+), 156 deletions(-) create mode 100644 llama_stack/providers/inline/inference/sentence_transformers/__init__.py create mode 100644 llama_stack/providers/inline/inference/sentence_transformers/config.py create mode 100644 llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py create mode 100644 llama_stack/providers/tests/inference/test_embeddings.py create mode 100644 llama_stack/providers/utils/inference/embedding_mixin.py diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index a17e8e48d..b037dfa66 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin): memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str chunk_size_in_tokens: int + embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2 overlap_size_in_tokens: Optional[int] = None diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index cb9cb1117..71101ec8b 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -20,6 +21,11 @@ class CommonModelFields(BaseModel): ) +class ModelType(Enum): + llm = "llm" + embedding_model = "embedding" + + @json_schema_type class Model(CommonModelFields, Resource): type: Literal[ResourceType.model.value] = ResourceType.model.value @@ -34,12 +40,14 @@ class Model(CommonModelFields, Resource): model_config = ConfigDict(protected_namespaces=()) + model_type: ModelType = Field(default=ModelType.llm) + class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None provider_model_id: Optional[str] = None - + model_type: Optional[ModelType] = ModelType.llm model_config = ConfigDict(protected_namespaces=()) @@ -59,6 +67,7 @@ class Models(Protocol): provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> Model: ... @webmethod(route="/models/unregister", method="POST") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5b75a525b..51be318cb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -88,9 +88,10 @@ class InferenceRouter(Inference): provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> None: await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata + model_id, provider_model_id, provider_id, metadata, model_type ) async def chat_completion( @@ -105,6 +106,13 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) params = dict( model_id=model_id, messages=messages, @@ -131,6 +139,13 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -150,6 +165,13 @@ class InferenceRouter(Inference): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.llm: + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2fb5a5e1c..bc3de8be0 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> Model: if provider_model_id is None: provider_model_id = model_id @@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): ) if metadata is None: metadata = {} + if model_type is None: + model_type = ModelType.llm + if ( + "embedding_dimension" not in metadata + and model_type == ModelType.embedding_model + ): + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, + model_type=model_type, ) registered_model = await self.register_object(model) return registered_model @@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." ) - memory_bank = parse_obj_as( - MemoryBank, - { - "identifier": memory_bank_id, - "type": ResourceType.memory_bank.value, - "provider_id": provider_id, - "provider_resource_id": provider_memory_bank_id, - **params.model_dump(), - }, - ) + model = await self.get_object_by_identifier("model", params.embedding_model) + if model is None: + raise ValueError(f"Model {params.embedding_model} not found") + if model.model_type != ModelType.embedding_model: + raise ValueError( + f"Model {params.embedding_model} is not an embedding model" + ) + if "embedding_dimension" not in model.metadata: + raise ValueError( + f"Model {params.embedding_model} does not have an embedding dimension" + ) + memory_bank_data = { + "identifier": memory_bank_id, + "type": ResourceType.memory_bank.value, + "provider_id": provider_id, + "provider_resource_id": provider_memory_bank_id, + **params.model_dump(), + } + if params.memory_bank_type == MemoryBankType.vector.value: + memory_bank_data["embedding_dimension"] = model.metadata[ + "embedding_dimension" + ] + memory_bank = parse_obj_as(MemoryBank, memory_bank_data) await self.register_object(memory_bank) return memory_bank diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 041a5677c..8f93c0c4b 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -40,7 +40,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v2" +KEY_VERSION = "v3" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 241497050..27490954b 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -200,10 +200,13 @@ API responses, specify the adapter here. return self.adapter.provider_data_validator -def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: +def remote_provider_spec( + api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None +) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, adapter=adapter, + api_dependencies=api_dependencies or [], ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 07fd4af44..e7abde227 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.utils.inference.model_registry import build_model_alias from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_media_to_url, request_has_media, ) - from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -32,12 +34,17 @@ log = logging.getLogger(__name__) SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): +class MetaReferenceInferenceImpl( + SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) - ModelRegistryHelper.__init__( - self, + if model is None: + raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") + self.model_registry_helper = ModelRegistryHelper( [ build_model_alias( model.descriptor(), @@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP ) ], ) - if model is None: - raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model # verify that the checkpoint actually is for this model lol @@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP async def unregister_model(self, model_id: str) -> None: pass + async def register_model(self, model: Model) -> Model: + model = await self.model_registry_helper.register_model(model) + if model.model_type == ModelType.embedding_model: + self._load_sentence_transformer_model(model.provider_resource_id) + return model + async def completion( self, model_id: str, @@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP for x in impl(): yield x - async def embeddings( - self, - model_id: str, - contents: List[InterleavedTextMedia], - ) -> EmbeddingsResponse: - raise NotImplementedError() - async def request_with_localized_media( request: Union[ChatCompletionRequest, CompletionRequest], diff --git a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py new file mode 100644 index 000000000..d5710f7fd --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.inline.inference.sentence_transformers.config import ( + SentenceTransformersInferenceConfig, +) + + +async def get_provider_impl( + config: SentenceTransformersInferenceConfig, + _deps, +): + from .sentence_transformers import SentenceTransformersInferenceImpl + + impl = SentenceTransformersInferenceImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/inference/sentence_transformers/config.py b/llama_stack/providers/inline/inference/sentence_transformers/config.py new file mode 100644 index 000000000..aec6d56d8 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class SentenceTransformersInferenceConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py new file mode 100644 index 000000000..0896b44af --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import AsyncGenerator, List, Optional, Union + +from llama_stack.apis.inference import ( + CompletionResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) +from .config import SentenceTransformersInferenceConfig + +log = logging.getLogger(__name__) + + +class SentenceTransformersInferenceImpl( + SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_model(self, model: Model) -> None: + _ = self._load_sentence_transformer_model(model.provider_resource_id) + return model + + async def unregister_model(self, model_id: str) -> None: + pass + + async def completion( + self, + model_id: str, + content: str, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, AsyncGenerator]: + raise ValueError("Sentence transformers don't support completion") + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + raise ValueError("Sentence transformers don't support chat completion") diff --git a/llama_stack/providers/inline/memory/faiss/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py index 16c383be3..2d7ede3b1 100644 --- a/llama_stack/providers/inline/memory/faiss/__init__.py +++ b/llama_stack/providers/inline/memory/faiss/__init__.py @@ -4,16 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec from .config import FaissImplConfig -async def get_provider_impl(config: FaissImplConfig, _deps): +async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]): from .faiss import FaissMemoryImpl assert isinstance( config, FaissImplConfig ), f"Unexpected config type: {type(config)}" - impl = FaissMemoryImpl(config) + impl = FaissMemoryImpl(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 78de13120..7c27aca85 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -19,11 +19,10 @@ from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -32,7 +31,8 @@ from .config import FaissImplConfig logger = logging.getLogger(__name__) -MEMORY_BANKS_PREFIX = "memory_banks:v1::" +MEMORY_BANKS_PREFIX = "memory_banks:v2::" +FAISS_INDEX_PREFIX = "faiss_index:v2::" class FaissIndex(EmbeddingIndex): @@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex): if not self.kvstore: return - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" stored_data = await self.kvstore.get(index_key) if stored_data: @@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex): "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" await self.kvstore.set(key=index_key, value=json.dumps(data)) async def delete(self): if not self.kvstore or not self.bank_id: return - await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") + await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}") async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + # Add dimension check + embedding_dim = ( + embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0] + ) + if embedding_dim != self.index.d: + raise ValueError( + f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}" + ) + indexlen = len(self.id_by_index) for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk @@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex): class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: FaissImplConfig) -> None: + def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cache = {} self.kvstore = None @@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): for bank_data in stored_banks: bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier + bank, + await FaissIndex.create( + bank.embedding_dimension, self.kvstore, bank.identifier ), + self.inference_api, ) self.cache[bank.identifier] = index @@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): ) # Store in cache - index = BankWithIndex( - bank=memory_bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + await FaissIndex.create( + memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier ), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 13d463ad8..0ff557b9f 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [ "transformers", "zmq", "lm-format-enforcer", + "sentence-transformers", ] @@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.inference.vllm", config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", ), + InlineProviderSpec( + api=Api.inference, + provider_type="inline::sentence-transformers", + pip_packages=["sentence-transformers"], + module="llama_stack.providers.inline.inference.sentence_transformers", + config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c52aba6c6..27c07e007 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", + api_dependencies=[Api.inference], ), InlineProviderSpec( api=Api.memory, @@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.chroma", config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig", ), + api_dependencies=[Api.inference], ), InlineProviderSpec( api=Api.memory, @@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.pgvector", config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", ), + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", ), + api_dependencies=[Api.inference], ), remote_provider_spec( api=Api.memory, @@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.sample", config_class="llama_stack.providers.remote.memory.sample.SampleConfig", ), + api_dependencies=[], ), remote_provider_spec( Api.memory, @@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.qdrant", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", ), + api_dependencies=[Api.inference], ), ] diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f575d9dc3..96cbcaa67 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -5,6 +5,7 @@ # the root directory of this source tree. from typing import * # noqa: F403 +import json from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.apis.inference import * # noqa: F403 + from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client +from llama_stack.providers.utils.inference.prompt_adapter import content_has_media model_aliases = [ @@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + embeddings = [] + for content in contents: + assert not content_has_media( + content + ), "Bedrock does not support media for embeddings" + input_text = interleaved_text_media_as_str(content) + input_body = {"inputText": input_text} + body = json.dumps(input_body) + response = self.client.invoke_model( + body=body, + modelId=model.provider_resource_id, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + embeddings.append(response_body.get("embedding")) + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 062c1e1ea..e69926942 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field @json_schema_type class FireworksImplConfig(BaseModel): url: str = Field( - default="https://api.fireworks.ai/inference", + default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) api_key: Optional[str] = Field( @@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel): @classmethod def sample_run_config(cls) -> Dict[str, Any]: return { - "url": "https://api.fireworks.ai/inference", + "url": "https://api.fireworks.ai/inference/v1", "api_key": "${env.FIREWORKS_API_KEY}", } diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index c3e634155..b0e93305e 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks from llama_models.datatypes import CoreModelId @@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -89,17 +90,19 @@ class FireworksInferenceAdapter( async def shutdown(self) -> None: pass - def _get_client(self) -> Fireworks: - fireworks_api_key = None + def _get_api_key(self) -> str: if self.config.api_key is not None: - fireworks_api_key = self.config.api_key + return self.config.api_key else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.fireworks_api_key: raise ValueError( 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": }' ) - fireworks_api_key = provider_data.fireworks_api_key + return provider_data.fireworks_api_key + + def _get_client(self) -> Fireworks: + fireworks_api_key = self._get_api_key() return Fireworks(api_key=fireworks_api_key) async def completion( @@ -264,4 +267,19 @@ class FireworksInferenceAdapter( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + kwargs = {} + if model.metadata.get("embedding_dimensions"): + kwargs["dimensions"] = model.metadata.get("embedding_dimensions") + assert all( + not content_has_media(content) for content in contents + ), "Fireworks does not support media for embeddings" + response = self._get_client().embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + **kwargs, + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index d6fa20835..1ba4ad599 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_image_media_to_url, request_has_media, ) @@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + assert all( + not content_has_media(content) for content in contents + ), "Ollama does not support media for embeddings" + response = await self.client.embed( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + ) + embeddings = response["embeddings"] + + return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: + # ollama does not have embedding models running. Check if the model is in list of available models. + if model.model_type == ModelType.embedding_model: + response = await self.client.list() + available_models = [m["model"] for m in response["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. " + f"Available models: {', '.join(available_models)}" + ) + return model model = await self.register_helper.register_model(model) models = await self.client.ps() available_models = [m["model"] for m in models["models"]] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e7c96ce98..7cd798d16 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -253,4 +254,13 @@ class TogetherInferenceAdapter( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + assert all( + not content_has_media(content) for content in contents + ), "Together does not support media for embeddings" + r = self._get_client().embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + ) + embeddings = [item.embedding for item in r.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 57f3db802..7ad5cef0f 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + kwargs = {} + assert model.model_type == ModelType.embedding_model + assert model.metadata.get("embedding_dimensions") + kwargs["dimensions"] = model.metadata.get("embedding_dimensions") + assert all( + not content_has_media(content) for content in contents + ), "VLLM does not support media for embeddings" + response = self.client.embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + **kwargs, + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index 63e9eae7d..581d60e75 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,12 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import ChromaRemoteImplConfig -async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps): +async def get_adapter_impl( + config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec] +): from .chroma import ChromaMemoryAdapter - impl = ChromaMemoryAdapter(config) + impl = ChromaMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index f4fb50a7c..20c81da3e 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -13,8 +13,7 @@ import chromadb from numpy.typing import NDArray from llama_stack.apis.memory import * # noqa: F403 - -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, @@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__( - self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig] + self, + config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig], + inference_api: Api.inference, ) -> None: log.info(f"Initializing ChromaMemoryAdapter with url: {config}") self.config = config + self.inference_api = inference_api + self.client = None self.cache = {} @@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): metadata={"bank": memory_bank.model_dump_json()}, ) ) - bank_index = BankWithIndex( - bank=memory_bank, index=ChromaIndex(self.client, collection) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, ChromaIndex(self.client, collection), self.inference_api ) - self.cache[memory_bank.identifier] = bank_index async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() @@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): collection = await maybe_await(self.client.get_collection(bank_id)) if not collection: raise ValueError(f"Bank {bank_id} not found in Chroma") - index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) + index = BankWithIndex( + bank, ChromaIndex(self.client, collection), self.inference_api + ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/pgvector/__init__.py b/llama_stack/providers/remote/memory/pgvector/__init__.py index 4ac30452f..b4620cae0 100644 --- a/llama_stack/providers/remote/memory/pgvector/__init__.py +++ b/llama_stack/providers/remote/memory/pgvector/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import PGVectorConfig -async def get_adapter_impl(config: PGVectorConfig, _deps): +async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]): from .pgvector import PGVectorMemoryAdapter - impl = PGVectorMemoryAdapter(config) + impl = PGVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 9ec76e8ca..0f295f38a 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate + from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex): class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: PGVectorConfig) -> None: + def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cursor = None self.conn = None self.cache = {} @@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def shutdown(self) -> None: pass - async def register_memory_bank( - self, - memory_bank: MemoryBank, - ) -> None: + async def register_memory_bank(self, memory_bank: MemoryBank) -> None: assert ( memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" - upsert_models( - self.cursor, - [ - (memory_bank.identifier, memory_bank), - ], + upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)]) + index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, index, self.inference_api ) - index = BankWithIndex( - bank=memory_bank, - index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[memory_bank.identifier] = index - async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] @@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + self.inference_api = inference_api + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: if bank_id in self.cache: return self.cache[bank_id] bank = await self.memory_bank_store.get_memory_bank(bank_id) - index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[bank_id] = index - return index + index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor) + self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api) + return self.cache[bank_id] diff --git a/llama_stack/providers/remote/memory/qdrant/__init__.py b/llama_stack/providers/remote/memory/qdrant/__init__.py index 9f54babad..54605fcf9 100644 --- a/llama_stack/providers/remote/memory/qdrant/__init__.py +++ b/llama_stack/providers/remote/memory/qdrant/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import QdrantConfig -async def get_adapter_impl(config: QdrantConfig, _deps): +async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]): from .qdrant import QdrantVectorMemoryAdapter - impl = QdrantVectorMemoryAdapter(config) + impl = QdrantVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index a9badbd6a..0f1a7c7d1 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex): class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: QdrantConfig) -> None: + def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None: self.config = config self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.cache = {} + self.inference_api = inference_api async def initialize(self) -> None: pass @@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = BankWithIndex( bank=memory_bank, index=QdrantIndex(self.client, memory_bank.identifier), + inference_api=self.inference_api, ) self.cache[memory_bank.identifier] = index @@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): index = BankWithIndex( bank=bank, index=QdrantIndex(client=self.client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/weaviate/__init__.py b/llama_stack/providers/remote/memory/weaviate/__init__.py index 504bd1508..f7120bec0 100644 --- a/llama_stack/providers/remote/memory/weaviate/__init__.py +++ b/llama_stack/providers/remote/memory/weaviate/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 -async def get_adapter_impl(config: WeaviateConfig, _deps): +async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]): from .weaviate import WeaviateMemoryAdapter - impl = WeaviateMemoryAdapter(config) + impl = WeaviateMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index f05fc663e..510915e65 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -12,10 +12,11 @@ import weaviate import weaviate.classes as wvc from numpy.typing import NDArray from weaviate.classes.init import Auth +from weaviate.classes.query import Filter from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self, chunk_ids: List[str]) -> None: + collection = self.client.collections.get(self.collection_name) + collection.data.delete_many( + where=Filter.by_property("id").contains_any(chunk_ids) + ) + class WeaviateMemoryAdapter( - Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate + Memory, + NeedsRequestProviderData, + MemoryBanksProtocolPrivate, ): - def __init__(self, config: WeaviateConfig) -> None: + def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.client_cache = {} self.cache = {} @@ -117,7 +127,7 @@ class WeaviateMemoryAdapter( memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.memory_bank_type == MemoryBankType.vector + memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" client = self._get_client() @@ -135,11 +145,11 @@ class WeaviateMemoryAdapter( ], ) - index = BankWithIndex( - bank=memory_bank, - index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: if bank_id in self.cache: @@ -156,6 +166,7 @@ class WeaviateMemoryAdapter( index = BankWithIndex( bank=bank, index=WeaviateIndex(client=client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 7fe19b403..54ebcd83a 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -18,6 +18,12 @@ def pytest_addoption(parser): default=None, help="Specify the inference model to use for testing", ) + parser.addoption( + "--embedding-model", + action="store", + default=None, + help="Specify the embedding model to use for testing", + ) def pytest_configure(config): diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 21e122149..ed0b0302d 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,9 +9,9 @@ import os import pytest import pytest_asyncio -from llama_stack.apis.models import ModelInput - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider + from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) @@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) + # If embedding dimension is set, use the 8B model for testing + if os.getenv("EMBEDDING_DIMENSION"): + inference_model = ["meta-llama/Llama-3.1-8B-Instruct"] return ProviderFixture( providers=[ @@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) - if "Llama3.1-8B-Instruct" in inference_model: + if inference_model and "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") return ProviderFixture( @@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [ async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") + model_type = ModelType.llm + metadata = {} + if os.getenv("EMBEDDING_DIMENSION"): + model_type = ModelType.embedding_model + metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION") + test_stack = await construct_stack_for_test( [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, - models=[ModelInput(model_id=inference_model)], + models=[ + ModelInput( + model_id=inference_model, + model_type=model_type, + metadata=metadata, + ) + ], ) return test_stack.impls[Api.inference], test_stack.impls[Api.models] diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py new file mode 100644 index 000000000..3502c6b20 --- /dev/null +++ b/llama_stack/providers/tests/inference/test_embeddings.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.inference import EmbeddingsResponse, ModelType + +# How to run this test: +# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py + + +class TestEmbeddings: + @pytest.mark.asyncio + async def test_embeddings(self, inference_model, inference_stack): + inference_impl, models_impl = inference_stack + model = await models_impl.get_model(inference_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + response = await inference_impl.embeddings( + model_id=inference_model, + contents=["Hello, world!"], + ) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) > 0 + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + @pytest.mark.asyncio + async def test_batch_embeddings(self, inference_model, inference_stack): + inference_impl, models_impl = inference_stack + model = await models_impl.get_model(inference_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + texts = ["Hello, world!", "This is a test", "Testing embeddings"] + + response = await inference_impl.embeddings( + model_id=inference_model, + contents=texts, + ) + + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == len(texts) + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + embedding_dim = len(response.embeddings[0]) + assert all(len(embedding) == embedding_dim for embedding in response.embeddings) diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 99ecbe794..7595538eb 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -6,9 +6,65 @@ import pytest +from ..conftest import get_provider_fixture_overrides + +from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import MEMORY_FIXTURES +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "inference": "meta_reference", + "memory": "faiss", + }, + id="meta_reference", + marks=pytest.mark.meta_reference, + ), + pytest.param( + { + "inference": "ollama", + "memory": "pgvector", + }, + id="ollama", + marks=pytest.mark.ollama, + ), + pytest.param( + { + "inference": "together", + "memory": "chroma", + }, + id="chroma", + marks=pytest.mark.chroma, + ), + pytest.param( + { + "inference": "bedrock", + "memory": "qdrant", + }, + id="qdrant", + marks=pytest.mark.qdrant, + ), + pytest.param( + { + "inference": "fireworks", + "memory": "weaviate", + }, + id="weaviate", + marks=pytest.mark.weaviate, + ), +] + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default=None, + help="Specify the inference model to use for testing", + ) + + def pytest_configure(config): for fixture_name in MEMORY_FIXTURES: config.addinivalue_line( @@ -18,12 +74,22 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): + if "inference_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--inference-model") + if not model: + raise ValueError( + "No inference model specified. Please provide a valid inference model." + ) + params = [pytest.param(model, id="")] + + metafunc.parametrize("inference_model", params, indirect=True) if "memory_stack" in metafunc.fixturenames: - metafunc.parametrize( - "memory_stack", - [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in MEMORY_FIXTURES - ], - indirect=True, + available_fixtures = { + "inference": INFERENCE_FIXTURES, + "memory": MEMORY_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS ) + metafunc.parametrize("memory_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index cc57bb916..92fd1720e 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,6 +10,8 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.inference import ModelInput, ModelType + from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig @@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") -async def memory_stack(request): - fixture_name = request.param - fixture = request.getfixturevalue(f"memory_{fixture_name}") +async def memory_stack(inference_model, request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["inference", "memory"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) test_stack = await construct_stack_for_test( - [Api.memory], - {"memory": fixture.providers}, - fixture.provider_data, + [Api.memory, Api.inference], + providers, + provider_data, + models=[ + ModelInput( + model_id=inference_model, + model_type=ModelType.embedding_model, + metadata={ + "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), + }, + ) + ], ) return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index b6e2e0a76..03597d073 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -45,12 +45,14 @@ def sample_documents(): ] -async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: +async def register_memory_bank( + banks_impl: MemoryBanks, inference_model: str +) -> MemoryBank: bank_id = f"test_bank_{uuid.uuid4().hex}" return await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: class TestMemory: @pytest.mark.asyncio - async def test_banks_list(self, memory_stack): + async def test_banks_list(self, memory_stack, inference_model): _, banks_impl = memory_stack # Register a test bank - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, inference_model) try: # Verify our bank shows up in list @@ -84,7 +86,7 @@ class TestMemory: ) @pytest.mark.asyncio - async def test_banks_register(self, memory_stack): + async def test_banks_register(self, memory_stack, inference_model): _, banks_impl = memory_stack bank_id = f"test_bank_{uuid.uuid4().hex}" @@ -94,7 +96,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -109,7 +111,7 @@ class TestMemory: await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -126,13 +128,15 @@ class TestMemory: await banks_impl.unregister_memory_bank(bank_id) @pytest.mark.asyncio - async def test_query_documents(self, memory_stack, sample_documents): + async def test_query_documents( + self, memory_stack, inference_model, sample_documents + ): memory_impl, banks_impl = memory_stack with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, inference_model) await memory_impl.insert_documents( registered_bank.memory_bank_id, sample_documents ) @@ -165,13 +169,13 @@ class TestMemory: # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.2} + params5 = {"score_threshold": 0.01} response5 = await memory_impl.query_documents( registered_bank.memory_bank_id, query5, params5 ) assert_valid_response(response5) print("The scores are:", response5.scores) - assert all(score >= 0.2 for score in response5.scores) + assert all(score >= 0.01 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse): diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py new file mode 100644 index 000000000..b53f8cd32 --- /dev/null +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import List + +from llama_models.llama3.api.datatypes import InterleavedTextMedia + +from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore + +EMBEDDING_MODELS = {} + + +log = logging.getLogger(__name__) + + +class SentenceTransformerEmbeddingMixin: + model_store: ModelStore + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + model = await self.model_store.get_model(model_id) + embedding_model = self._load_sentence_transformer_model( + model.provider_resource_id + ) + embeddings = embedding_model.encode(contents) + return EmbeddingsResponse(embeddings=embeddings) + + def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": + global EMBEDDING_MODELS + + loaded_model = EMBEDDING_MODELS.get(model) + if loaded_model is not None: + return loaded_model + + log.info(f"Loading sentence transformer for {model}...") + from sentence_transformers import SentenceTransformer + + loaded_model = SentenceTransformer(model) + EMBEDDING_MODELS[model] = loaded_model + return loaded_model diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 8dbfab14a..be2642cdb 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -9,6 +9,7 @@ from typing import List, Optional from llama_models.sku_list import all_registered_models +from llama_stack.apis.models.models import ModelType from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( @@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate): return None async def register_model(self, model: Model) -> Model: - provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + if model.model_type == ModelType.embedding_model: + # embedding models are always registered by their provider model id and does not need to be mapped to a llama model + provider_resource_id = model.provider_resource_id + else: + provider_resource_id = self.get_provider_model_id( + model.provider_resource_id + ) if provider_resource_id: model.provider_resource_id = provider_resource_id else: diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index eb83aa671..cebe897bc 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) -ALL_MINILM_L6_V2_DIMENSION = 384 - -EMBEDDING_MODELS = {} - - -def get_embedding_model(model: str) -> "SentenceTransformer": - global EMBEDDING_MODELS - - loaded_model = EMBEDDING_MODELS.get(model) - if loaded_model is not None: - return loaded_model - - log.info(f"Loading sentence transformer for {model}...") - from sentence_transformers import SentenceTransformer - - loaded_model = SentenceTransformer(model) - EMBEDDING_MODELS[model] = loaded_model - return loaded_model - def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string @@ -166,12 +148,12 @@ class EmbeddingIndex(ABC): class BankWithIndex: bank: VectorMemoryBank index: EmbeddingIndex + inference_api: Api.inference async def insert_documents( self, documents: List[MemoryBankDocument], ) -> None: - model = get_embedding_model(self.bank.embedding_model) for doc in documents: content = await content_from_doc(doc) chunks = make_overlapped_chunks( @@ -183,7 +165,10 @@ class BankWithIndex: ) if not chunks: continue - embeddings = model.encode([x.content for x in chunks]).astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [x.content for x in chunks] + ) + embeddings = np.array(embeddings_response.embeddings) await self.index.add_chunks(chunks, embeddings) @@ -208,6 +193,8 @@ class BankWithIndex: else: query_str = _process(query) - model = get_embedding_model(self.bank.embedding_model) - query_vector = model.encode([query_str])[0].astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [query_str] + ) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) return await self.index.query(query_vector, k, score_threshold)