diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 8e89bcc72..241497050 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -53,8 +53,6 @@ class ShieldsProtocolPrivate(Protocol): class MemoryBanksProtocolPrivate(Protocol): - async def list_memory_banks(self) -> List[MemoryBank]: ... - async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ... async def unregister_memory_bank(self, memory_bank_id: str) -> None: ... diff --git a/llama_stack/providers/inline/memory/chroma/__init__.py b/llama_stack/providers/inline/memory/chroma/__init__.py new file mode 100644 index 000000000..44279abd1 --- /dev/null +++ b/llama_stack/providers/inline/memory/chroma/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import ChromaInlineImplConfig + + +async def get_provider_impl(config: ChromaInlineImplConfig, _deps): + from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter + + impl = ChromaMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/memory/chroma/config.py b/llama_stack/providers/inline/memory/chroma/config.py new file mode 100644 index 000000000..efbd77faf --- /dev/null +++ b/llama_stack/providers/inline/memory/chroma/config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + + +class ChromaInlineImplConfig(BaseModel): + db_path: str + + @classmethod + def sample_config(cls) -> Dict[str, Any]: + return {"db_path": "{env.CHROMADB_PATH}"} diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index ff0926108..c52aba6c6 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -53,9 +53,16 @@ def available_providers() -> List[ProviderSpec]: adapter_type="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], module="llama_stack.providers.remote.memory.chroma", - config_class="llama_stack.distribution.datatypes.RemoteProviderConfig", + config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig", ), ), + InlineProviderSpec( + api=Api.memory, + provider_type="inline::chromadb", + pip_packages=EMBEDDING_DEPS + ["chromadb"], + module="llama_stack.providers.inline.memory.chroma", + config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig", + ), remote_provider_spec( Api.memory, AdapterSpec( diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index dfd5c5696..63e9eae7d 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.distribution.datatypes import RemoteProviderConfig +from .config import ChromaRemoteImplConfig -async def get_adapter_impl(config: RemoteProviderConfig, _deps): +async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps): from .chroma import ChromaMemoryAdapter - impl = ChromaMemoryAdapter(config.url) + impl = ChromaMemoryAdapter(config) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 207f6b54d..f4fb50a7c 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - +import asyncio import json import logging from typing import List @@ -12,21 +12,31 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray -from pydantic import parse_obj_as - from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, ) +from .config import ChromaRemoteImplConfig log = logging.getLogger(__name__) +ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient] + + +# this is a helper to allow us to use async and non-async chroma clients interchangeably +async def maybe_await(result): + if asyncio.iscoroutine(result): + return await result + return result + + class ChromaIndex(EmbeddingIndex): - def __init__(self, client: chromadb.AsyncHttpClient, collection): + def __init__(self, client: ChromaClientType, collection): self.client = client self.collection = collection @@ -35,19 +45,23 @@ class ChromaIndex(EmbeddingIndex): embeddings ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" - await self.collection.add( - documents=[chunk.json() for chunk in chunks], - embeddings=embeddings, - ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], + await maybe_await( + self.collection.add( + documents=[chunk.model_dump_json() for chunk in chunks], + embeddings=embeddings, + ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)], + ) ) async def query( self, embedding: NDArray, k: int, score_threshold: float ) -> QueryDocumentsResponse: - results = await self.collection.query( - query_embeddings=[embedding.tolist()], - n_results=k, - include=["documents", "distances"], + results = await maybe_await( + self.collection.query( + query_embeddings=[embedding.tolist()], + n_results=k, + include=["documents", "distances"], + ) ) distances = results["distances"][0] documents = results["documents"][0] @@ -68,31 +82,33 @@ class ChromaIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) async def delete(self): - await self.client.delete_collection(self.collection.name) + await maybe_await(self.client.delete_collection(self.collection.name)) class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, url: str) -> None: - log.info(f"Initializing ChromaMemoryAdapter with url: {url}") - url = url.rstrip("/") - parsed = urlparse(url) - - if parsed.path and parsed.path != "/": - raise ValueError("URL should not contain a path") - - self.host = parsed.hostname - self.port = parsed.port - + def __init__( + self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig] + ) -> None: + log.info(f"Initializing ChromaMemoryAdapter with url: {config}") + self.config = config self.client = None self.cache = {} async def initialize(self) -> None: - try: - log.info(f"Connecting to Chroma server at: {self.host}:{self.port}") - self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port) - except Exception as e: - log.exception("Could not connect to Chroma server") - raise RuntimeError("Could not connect to Chroma server") from e + if isinstance(self.config, ChromaRemoteImplConfig): + log.info(f"Connecting to Chroma server at: {self.config.url}") + url = self.config.url.rstrip("/") + parsed = urlparse(url) + + if parsed.path and parsed.path != "/": + raise ValueError("URL should not contain a path") + + self.client = await chromadb.AsyncHttpClient( + host=parsed.hostname, port=parsed.port + ) + else: + log.info(f"Connecting to Chroma local db at: {self.config.db_path}") + self.client = chromadb.PersistentClient(path=self.config.db_path) async def shutdown(self) -> None: pass @@ -105,33 +121,17 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" - collection = await self.client.get_or_create_collection( - name=memory_bank.identifier, - metadata={"bank": memory_bank.model_dump_json()}, + collection = await maybe_await( + self.client.get_or_create_collection( + name=memory_bank.identifier, + metadata={"bank": memory_bank.model_dump_json()}, + ) ) bank_index = BankWithIndex( bank=memory_bank, index=ChromaIndex(self.client, collection) ) self.cache[memory_bank.identifier] = bank_index - async def list_memory_banks(self) -> List[MemoryBank]: - collections = await self.client.list_collections() - for collection in collections: - try: - data = json.loads(collection.metadata["bank"]) - bank = parse_obj_as(VectorMemoryBank, data) - except Exception: - log.exception(f"Failed to parse bank: {collection.metadata}") - continue - - index = BankWithIndex( - bank=bank, - index=ChromaIndex(self.client, collection), - ) - self.cache[bank.identifier] = index - - return [i.bank for i in self.cache.values()] - async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] @@ -163,7 +163,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): bank = await self.memory_bank_store.get_memory_bank(bank_id) if not bank: raise ValueError(f"Bank {bank_id} not found in Llama Stack") - collection = await self.client.get_collection(bank_id) + collection = await maybe_await(self.client.get_collection(bank_id)) if not collection: raise ValueError(f"Bank {bank_id} not found in Chroma") index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) diff --git a/llama_stack/providers/remote/memory/chroma/config.py b/llama_stack/providers/remote/memory/chroma/config.py new file mode 100644 index 000000000..68ca2c967 --- /dev/null +++ b/llama_stack/providers/remote/memory/chroma/config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + + +class ChromaRemoteImplConfig(BaseModel): + url: str + + @classmethod + def sample_config(cls) -> Dict[str, Any]: + return {"url": "{env.CHROMADB_URL}"} diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index d77de7b41..9ec76e8ca 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -185,17 +185,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] - async def list_memory_banks(self) -> List[MemoryBank]: - banks = load_models(self.cursor, VectorMemoryBank) - for bank in banks: - if bank.identifier not in self.cache: - index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[bank.identifier] = index - return banks - async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index be370eec9..a9badbd6a 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -127,11 +127,6 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBank]: - # Qdrant doesn't have collection level metadata to store the bank properties - # So we only return from the cache value - return [i.bank for i in self.cache.values()] - async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: if bank_id in self.cache: return self.cache[bank_id] diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py index 3431b87d5..09ea2f32c 100644 --- a/llama_stack/providers/remote/memory/sample/sample.py +++ b/llama_stack/providers/remote/memory/sample/sample.py @@ -14,7 +14,7 @@ class SampleMemoryImpl(Memory): def __init__(self, config: SampleConfig): self.config = config - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: + async def register_memory_bank(self, memory_bank: MemoryBank) -> None: # these are the memory banks the Llama Stack will use to route requests to this provider # perform validation here if necessary pass diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index f8fba5c0b..f05fc663e 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -141,13 +141,6 @@ class WeaviateMemoryAdapter( ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBank]: - # TODO: right now the Llama Stack is the source of truth for these banks. That is - # not ideal. It should be Weaviate which is the source of truth. Unfortunately, - # list() happens at Stack startup when the Weaviate client (credentials) is not - # yet available. We need to figure out a way to make this work. - return [i.bank for i in self.cache.values()] - async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: if bank_id in self.cache: return self.cache[bank_id] diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..cc57bb916 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,8 +10,10 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig +from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig +from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test @@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture: @pytest.fixture(scope="session") def memory_chroma() -> ProviderFixture: + url = os.getenv("CHROMA_URL") + if url: + config = ChromaRemoteImplConfig(url=url) + provider_type = "remote::chromadb" + else: + if not os.getenv("CHROMA_DB_PATH"): + raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set") + config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH")) + provider_type = "inline::chromadb" return ProviderFixture( providers=[ Provider( provider_id="chroma", - provider_type="remote::chromadb", - config=RemoteProviderConfig( - host=get_env_or_fail("CHROMA_HOST"), - port=get_env_or_fail("CHROMA_PORT"), - ).model_dump(), + provider_type=provider_type, + config=config.model_dump(), ) ] )