Allow using an "inline" version of Chroma using PersistentClient (#567)

The same code is used (inside providers/remote/memory/chroma/chroma.py)
but it is driven by separate configurations and changes which Chroma
client to use. Note that the dependencies are separate
(`chromadb-client` vs `chromadb` -- the latter is a _much_ heavier
package.)

```
pytest -s -v -m chroma memory/test_memory.py --env CHROMA_DB_PATH=/tmp/chroma_test
pytest -s -v -m chroma memory/test_memory.py --env CHROMA_URL=http://localhost:6001
```
This commit is contained in:
Ashwin Bharambe 2024-12-11 16:02:04 -08:00 committed by GitHub
parent 41487e6ed1
commit b7cb06f004
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 127 additions and 88 deletions

View file

@ -53,8 +53,6 @@ class ShieldsProtocolPrivate(Protocol):
class MemoryBanksProtocolPrivate(Protocol): class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBank]: ...
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ... async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ChromaInlineImplConfig
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ChromaInlineImplConfig(BaseModel):
db_path: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {"db_path": "{env.CHROMADB_PATH}"}

View file

@ -53,9 +53,16 @@ def available_providers() -> List[ProviderSpec]:
adapter_type="chromadb", adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"], pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma", module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig", config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
), ),
), ),
InlineProviderSpec(
api=Api.memory,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
module="llama_stack.providers.inline.memory.chroma",
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
AdapterSpec( AdapterSpec(

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.distribution.datatypes import RemoteProviderConfig from .config import ChromaRemoteImplConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps): async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
from .chroma import ChromaMemoryAdapter from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url) impl = ChromaMemoryAdapter(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
import logging import logging
from typing import List from typing import List
@ -12,21 +12,31 @@ from urllib.parse import urlparse
import chromadb import chromadb
from numpy.typing import NDArray from numpy.typing import NDArray
from pydantic import parse_obj_as
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
) )
from .config import ChromaRemoteImplConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
# this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result):
if asyncio.iscoroutine(result):
return await result
return result
class ChromaIndex(EmbeddingIndex): class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection): def __init__(self, client: ChromaClientType, collection):
self.client = client self.client = client
self.collection = collection self.collection = collection
@ -35,19 +45,23 @@ class ChromaIndex(EmbeddingIndex):
embeddings embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
await self.collection.add( await maybe_await(
documents=[chunk.json() for chunk in chunks], self.collection.add(
embeddings=embeddings, documents=[chunk.model_dump_json() for chunk in chunks],
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
) )
async def query( async def query(
self, embedding: NDArray, k: int, score_threshold: float self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse: ) -> QueryDocumentsResponse:
results = await self.collection.query( results = await maybe_await(
query_embeddings=[embedding.tolist()], self.collection.query(
n_results=k, query_embeddings=[embedding.tolist()],
include=["documents", "distances"], n_results=k,
include=["documents", "distances"],
)
) )
distances = results["distances"][0] distances = results["distances"][0]
documents = results["documents"][0] documents = results["documents"][0]
@ -68,31 +82,33 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self): async def delete(self):
await self.client.delete_collection(self.collection.name) await maybe_await(self.client.delete_collection(self.collection.name))
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None: def __init__(
log.info(f"Initializing ChromaMemoryAdapter with url: {url}") self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
url = url.rstrip("/") ) -> None:
parsed = urlparse(url) log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
self.config = config
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.host = parsed.hostname
self.port = parsed.port
self.client = None self.client = None
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize(self) -> None:
try: if isinstance(self.config, ChromaRemoteImplConfig):
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}") log.info(f"Connecting to Chroma server at: {self.config.url}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port) url = self.config.url.rstrip("/")
except Exception as e: parsed = urlparse(url)
log.exception("Could not connect to Chroma server")
raise RuntimeError("Could not connect to Chroma server") from e if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.client = await chromadb.AsyncHttpClient(
host=parsed.hostname, port=parsed.port
)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
@ -105,33 +121,17 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
memory_bank.memory_bank_type == MemoryBankType.vector.value memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await self.client.get_or_create_collection( collection = await maybe_await(
name=memory_bank.identifier, self.client.get_or_create_collection(
metadata={"bank": memory_bank.model_dump_json()}, name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
)
) )
bank_index = BankWithIndex( bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection) bank=memory_bank, index=ChromaIndex(self.client, collection)
) )
self.cache[memory_bank.identifier] = bank_index self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections()
for collection in collections:
try:
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(VectorMemoryBank, data)
except Exception:
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank.identifier] = index
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None: async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete() await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id] del self.cache[memory_bank_id]
@ -163,7 +163,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
bank = await self.memory_bank_store.get_memory_bank(bank_id) bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank: if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack") raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id) collection = await maybe_await(self.client.get_collection(bank_id))
if not collection: if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma") raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ChromaRemoteImplConfig(BaseModel):
url: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {"url": "{env.CHROMADB_URL}"}

View file

@ -185,17 +185,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
await self.cache[memory_bank_id].index.delete() await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id] del self.cache[memory_bank_id]
async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank.identifier] = index
return banks
async def insert_documents( async def insert_documents(
self, self,
bank_id: str, bank_id: str,

View file

@ -127,11 +127,6 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
# Qdrant doesn't have collection level metadata to store the bank properties
# So we only return from the cache value
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]

View file

@ -14,7 +14,7 @@ class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig): def __init__(self, config: SampleConfig):
self.config = config self.config = config
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider # these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary # perform validation here if necessary
pass pass

View file

@ -141,13 +141,6 @@ class WeaviateMemoryAdapter(
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
# list() happens at Stack startup when the Weaviate client (credentials) is not
# yet available. We need to figure out a way to make this work.
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]

View file

@ -10,8 +10,10 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture: def memory_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture( return ProviderFixture(
providers=[ providers=[
Provider( Provider(
provider_id="chroma", provider_id="chroma",
provider_type="remote::chromadb", provider_type=provider_type,
config=RemoteProviderConfig( config=config.model_dump(),
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
) )
] ]
) )