implement embedding generation in supported inference providers (#589)

This PR adds the ability to generate embeddings in all supported
inference providers.
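
At the API level, this means every inference provider can now serve an `embeddings` call alongside chat completion. Below is a minimal usage sketch against a running stack; the client import and method surface are assumptions based on the stock llama-stack client, not part of this diff:

```
# Hedged sketch: exercising the new embeddings path from a client.
# Assumes a stack running locally and a client exposing
# `inference.embeddings(model_id, contents)` that returns an object with an
# `embeddings: List[List[float]]` field; adjust names to your client version.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")
response = client.inference.embeddings(
    model_id="sentence-transformers/all-MiniLM-L6-v2",
    contents=["The quick brown fox jumps over the lazy dog"],
)
# One vector per input; its length should match EMBEDDING_DIMENSION (384 here).
print(len(response.embeddings[0]))
```

The change was exercised against each provider: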

```
pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024

pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1"

pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=768

pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

```
Dinesh Yeduguru committed 2024-12-12 11:17:39 -08:00
parent 6a23f24ee0
commit d362d2d740
32 changed files with 597 additions and 143 deletions

llama_stack/providers/remote/memory/chroma/__init__.py

```
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
 from .config import ChromaRemoteImplConfig
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```
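
The same dependency-injection change repeats in each adapter package below: `get_adapter_impl` now receives the full provider dependency map instead of ignoring it, and hands the inference provider to the memory adapter. A hedged sketch of the call site, with illustrative variable names not taken from this diff:

```
# Hypothetical wiring: the distribution resolver assembles a map keyed by Api
# and passes it to each adapter factory, which pulls out the inference impl.
deps = {Api.inference: inference_provider}
memory_impl = await get_adapter_impl(config, deps)
```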

llama_stack/providers/remote/memory/chroma/chroma.py

```
@@ -13,8 +13,7 @@ import chromadb
 from numpy.typing import NDArray
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
-
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):
 
 
 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(
-        self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
+        self,
+        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
+        inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
         self.config = config
+        self.inference_api = inference_api
+
         self.client = None
         self.cache = {}
@@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 metadata={"bank": memory_bank.model_dump_json()},
             )
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
@@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
```
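
The recurring pattern in this and the following adapters: `BankWithIndex` now carries the inference API, so chunk embeddings are produced by whichever inference provider the stack is configured with rather than by a hard-coded local model. A minimal sketch of what that delegation plausibly looks like; field and method names are assumptions, the real implementation lives in `llama_stack/providers/utils/memory/vector_store.py`:

```
# Hedged sketch of BankWithIndex delegating embedding generation.
from dataclasses import dataclass
from typing import Any, List

import numpy as np


@dataclass
class BankWithIndex:
    bank: Any           # MemoryBank: carries embedding_model / embedding_dimension
    index: Any          # EmbeddingIndex backend (Chroma, PGVector, Qdrant, Weaviate)
    inference_api: Any  # injected inference provider

    async def insert_chunks(self, chunks: List[Any]) -> None:
        # Embed through the configured provider instead of a local model.
        response = await self.inference_api.embeddings(
            self.bank.embedding_model, [c.content for c in chunks]
        )
        embeddings = np.array(response.embeddings, dtype=np.float32)
        await self.index.add_chunks(chunks, embeddings)
```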

llama_stack/providers/remote/memory/pgvector/__init__.py

```
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig
 
 
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
 
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

llama_stack/providers/remote/memory/pgvector/pgvector.py

```
@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
 
 
 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
+
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
```
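
Note the second change threaded through pgvector: the hard-coded `ALL_MINILM_L6_V2_DIMENSION` constant is gone and the vector column is sized from `bank.embedding_dimension`, which is what lets the tests above cover dimensions from 128 to 4096. A hedged sketch of where that dimension plausibly lands; table and column names are illustrative:

```
# Hypothetical sketch: the bank's embedding dimension sizes the pgvector column.
def create_table_sql(table_name: str, dimension: int) -> str:
    return (
        f"CREATE TABLE IF NOT EXISTS {table_name} ("
        "id TEXT PRIMARY KEY, "
        "document JSONB, "
        f"embedding vector({dimension}))"
    )

# e.g. cursor.execute(create_table_sql("vector_store_my_bank", bank.embedding_dimension))
```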

llama_stack/providers/remote/memory/qdrant/__init__.py

```
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

llama_stack/providers/remote/memory/qdrant/qdrant.py

```
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
 
 
 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
```

llama_stack/providers/remote/memory/weaviate/__init__.py

```
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

llama_stack/providers/remote/memory/weaviate/weaviate.py

```
@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
+
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
@@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
```