Vector store inference API (#598)

# What does this PR do?
Moves all the memory providers to use the inference API for embeddings, and improves the
memory tests to set up the inference stack correctly and exercise the
embedding models.
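
The heart of the change is in `BankWithIndex`: instead of loading a `SentenceTransformer` locally, providers now call `inference_api.embeddings(model, contents)` and read the `.embeddings` list off the response. A minimal sketch of that contract, with a stub standing in for a real inference provider (the stub class and the 384-dim zero vectors are illustrative, not part of this PR):

```python
import asyncio
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class EmbeddingsResponse:
    embeddings: List[List[float]]


class StubInferenceApi:
    """Illustrative stand-in for a real Inference provider."""

    async def embeddings(self, model: str, contents: List[str]) -> EmbeddingsResponse:
        # A real provider runs the embedding model named by `model`;
        # here we just return one 384-dim vector per input.
        return EmbeddingsResponse(embeddings=[[0.0] * 384 for _ in contents])


async def main():
    response = await StubInferenceApi().embeddings(
        "sentence-transformers/all-MiniLM-L6-v2", ["hello", "world"]
    )
    print(np.array(response.embeddings).shape)  # (2, 384)


asyncio.run(main())
```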


## Test Plan
```bash
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" \
  --inference-model="Llama3.2-3B-Instruct" \
  --embedding-model="sentence-transformers/all-MiniLM-L6-v2" \
  llama_stack/providers/tests/inference/test_embeddings.py \
  --env EMBEDDING_DIMENSION=384
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=weaviate" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=chroma" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env CHROMA_HOST=localhost --env CHROMA_PORT=8000
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=pgvector" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres \
  --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=faiss" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
```
Dinesh Yeduguru · 2024-12-12 11:16:54 -08:00 · commit 4f8b73b9e1 (parent db7b26a8c9), committed via GitHub
15 changed files with 235 additions and 118 deletions

```diff
@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import FaissImplConfig
 
 
-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl
 
     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"
-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -19,11 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -32,7 +31,8 @@ from .config import FaissImplConfig
 
 logger = logging.getLogger(__name__)
 
-MEMORY_BANKS_PREFIX = "memory_banks:v1::"
+MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+FAISS_INDEX_PREFIX = "faiss_index:v2::"
 
 
 class FaissIndex(EmbeddingIndex):
@@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore:
             return
 
-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)
 
         if stored_data:
@@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }
 
-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))
 
     async def delete(self):
         if not self.kvstore or not self.bank_id:
             return
 
-        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
+        await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Add dimension check
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
@@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
 
 class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None
@@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index
@@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )
 
         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]
```
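
The new guard in `FaissIndex.add_chunks` fails fast when a bank's configured `embedding_dimension` does not match the vectors the inference API actually produced. A standalone sketch of the check, assuming `faiss-cpu` is installed; the 384/768 dimensions are illustrative:

```python
import faiss
import numpy as np

index = faiss.IndexFlatL2(384)  # sized from the bank's embedding_dimension

embeddings = np.random.rand(4, 768).astype(np.float32)  # wrong width on purpose
embedding_dim = (
    embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != index.d:
    # mirrors the ValueError raised in add_chunks above
    print(f"Embedding dimension mismatch. Expected {index.d}, got {embedding_dim}")
else:
    index.add(embeddings)
```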

```diff
@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,
@@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,
@@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
```

```diff
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config.url)
+    impl = ChromaMemoryAdapter(config.url, deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -15,8 +15,7 @@ from numpy.typing import NDArray
 from pydantic import parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -72,7 +71,7 @@ class ChromaIndex(EmbeddingIndex):
 
 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, inference_api: Api.inference) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
         parsed = urlparse(url)
@@ -82,6 +81,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
         self.host = parsed.hostname
         self.port = parsed.port
+        self.inference_api = inference_api
         self.client = None
         self.cache = {}
@@ -109,10 +109,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         collections = await self.client.list_collections()
@@ -124,11 +123,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue
 
-            index = BankWithIndex(
-                bank=bank,
-                index=ChromaIndex(self.client, collection),
+            self.cache[bank.identifier] = BankWithIndex(
+                bank,
+                ChromaIndex(self.client, collection),
+                self.inference_api,
             )
-            self.cache[bank.identifier] = index
 
         return [i.bank for i in self.cache.values()]
@@ -166,6 +165,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
```

```diff
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig
 
 
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
 
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
 
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
 
 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -190,8 +181,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         for bank in banks:
             if bank.identifier not in self.cache:
                 index = BankWithIndex(
-                    bank=bank,
-                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+                    bank,
+                    PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
+                    self.inference_api,
                 )
                 self.cache[bank.identifier] = index
         return banks
@@ -214,14 +206,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
```

```diff
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
 
 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -143,6 +145,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
```

```diff
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
+
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is
@@ -163,6 +173,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
```

```diff
@@ -6,9 +6,65 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
 
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
+
 
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -18,12 +74,22 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if not model:
+            raise ValueError(
+                "No embedding model specified. Please provide a valid embedding model."
+            )
+        params = [pytest.param(model, id="")]
+        metafunc.parametrize("embedding_model", params, indirect=True)
+
     if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
```

```diff
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
@@ -97,14 +99,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(embedding_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=embedding_model,
+                model_type=ModelType.embedding_model,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )
 
     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
```
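
With this fixture change, the embedding model is registered on the test stack like any other model, with its dimension carried in metadata. A sketch of the registration payload the fixture builds, reading `EMBEDDING_DIMENSION` from the environment (the model id and the 384 fallback here are illustrative):

```python
import os

from llama_stack.apis.inference import ModelInput, ModelType

embedding_model = ModelInput(
    model_id="sentence-transformers/all-MiniLM-L6-v2",
    model_type=ModelType.embedding_model,
    # the fixture itself fails hard via get_env_or_fail; a default is used
    # here only to keep the sketch self-contained
    metadata={"embedding_dimension": os.environ.get("EMBEDDING_DIMENSION", "384")},
)
```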

```diff
@@ -45,12 +45,14 @@ def sample_documents():
     ]
 
 
-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, embedding_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
 
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
 
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
 
         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@ class TestMemory:
             )
 
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
 
         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=embedding_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -109,7 +111,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=embedding_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)
 
     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, embedding_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack
 
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
 
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
@@ -165,13 +169,13 @@ class TestMemory:
 
         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)
 
 
 def assert_valid_response(response: QueryDocumentsResponse):
```

```diff
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
 
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
 
@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
```
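
One subtlety in the new `BankWithIndex` paths: the inference API returns plain Python lists, so the insert path's bare `np.array(...)` defaults to float64 while the query path explicitly requests float32; index implementations that need a specific dtype (FAISS wants float32) must convert. A tiny sketch of the difference, with illustrative values:

```python
import numpy as np

# chunk embeddings as converted on the insert path (no dtype given)
embeddings = np.array([[0.1] * 384, [0.2] * 384])        # dtype float64 by default

# query vector as converted on the query path (dtype forced)
query_vector = np.array([0.1] * 384, dtype=np.float32)   # dtype float32

print(embeddings.dtype, query_vector.dtype)  # float64 float32
```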