diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 8e89bcc72..241497050 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -53,8 +53,6 @@ class ShieldsProtocolPrivate(Protocol):
 
 
 class MemoryBanksProtocolPrivate(Protocol):
-    async def list_memory_banks(self) -> List[MemoryBank]: ...
-
     async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index ff0926108..178cae574 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -53,7 +53,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
             module="llama_stack.providers.remote.memory.chroma",
-            config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
+            config_class="llama_stack.providers.remote.memory.chroma.ChromaConfig",
         ),
     ),
     remote_provider_spec(
diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py
index dfd5c5696..ee97d9705 100644
--- a/llama_stack/providers/remote/memory/chroma/__init__.py
+++ b/llama_stack/providers/remote/memory/chroma/__init__.py
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.distribution.datatypes import RemoteProviderConfig
+from .config import ChromaConfig
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+async def get_adapter_impl(config: ChromaConfig, _deps):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config.url)
+    impl = ChromaMemoryAdapter(config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py
index 207f6b54d..1b99c044d 100644
--- a/llama_stack/providers/remote/memory/chroma/chroma.py
+++ b/llama_stack/providers/remote/memory/chroma/chroma.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import asyncio
 import json
 import logging
 from typing import List
@@ -12,8 +12,6 @@ from urllib.parse import urlparse
 
 import chromadb
 from numpy.typing import NDArray
-from pydantic import parse_obj_as
-
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 
@@ -21,12 +19,23 @@ from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
 )
+from .config import ChromaConfig
 
 log = logging.getLogger(__name__)
 
 
+ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
+
+
+# this is a helper to allow us to use async and non-async chroma clients interchangeably
+async def maybe_await(result):
+    if asyncio.iscoroutine(result):
+        return await result
+    return result
+
+
 class ChromaIndex(EmbeddingIndex):
-    def __init__(self, client: chromadb.AsyncHttpClient, collection):
+    def __init__(self, client: ChromaClientType, collection):
         self.client = client
         self.collection = collection
 
@@ -35,19 +44,23 @@ class ChromaIndex(EmbeddingIndex):
             embeddings
         ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
 
-        await self.collection.add(
-            documents=[chunk.json() for chunk in chunks],
-            embeddings=embeddings,
-            ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
+        await maybe_await(
+            self.collection.add(
+                documents=[chunk.model_dump_json() for chunk in chunks],
+                embeddings=embeddings,
+                ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
+            )
         )
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
     ) -> QueryDocumentsResponse:
-        results = await self.collection.query(
-            query_embeddings=[embedding.tolist()],
-            n_results=k,
-            include=["documents", "distances"],
+        results = await maybe_await(
+            self.collection.query(
+                query_embeddings=[embedding.tolist()],
+                n_results=k,
+                include=["documents", "distances"],
+            )
         )
         distances = results["distances"][0]
         documents = results["documents"][0]
@@ -68,12 +81,12 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
     async def delete(self):
-        await self.client.delete_collection(self.collection.name)
+        await maybe_await(self.client.delete_collection(self.collection.name))
 
 
 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, url: str) -> None:
-        log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
+    def __init__(self, config: ChromaConfig) -> None:
+        log.info(f"Initializing ChromaMemoryAdapter with url: {config.url}")
 
         url = url.rstrip("/")
         parsed = urlparse(url)
@@ -88,8 +101,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
     async def initialize(self) -> None:
         try:
-            log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
-            self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
+            if self.config.url:
+                log.info(f"Connecting to Chroma server at: {self.config.url}")
+                self.client = await chromadb.AsyncHttpClient(url=self.config.url)
+            else:
+                log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
+                self.client = chromadb.PersistentClient(path=self.config.db_path)
         except Exception as e:
             log.exception("Could not connect to Chroma server")
             raise RuntimeError("Could not connect to Chroma server") from e
@@ -105,33 +122,17 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        collection = await self.client.get_or_create_collection(
-            name=memory_bank.identifier,
-            metadata={"bank": memory_bank.model_dump_json()},
+        collection = await maybe_await(
+            self.client.get_or_create_collection(
+                name=memory_bank.identifier,
+                metadata={"bank": memory_bank.model_dump_json()},
+            )
         )
         bank_index = BankWithIndex(
             bank=memory_bank, index=ChromaIndex(self.client, collection)
         )
         self.cache[memory_bank.identifier] = bank_index
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        collections = await self.client.list_collections()
-        for collection in collections:
-            try:
-                data = json.loads(collection.metadata["bank"])
-                bank = parse_obj_as(VectorMemoryBank, data)
-            except Exception:
-                log.exception(f"Failed to parse bank: {collection.metadata}")
-                continue
-
-            index = BankWithIndex(
-                bank=bank,
-                index=ChromaIndex(self.client, collection),
-            )
-            self.cache[bank.identifier] = index
-
-        return [i.bank for i in self.cache.values()]
-
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -163,7 +164,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
         if not bank:
             raise ValueError(f"Bank {bank_id} not found in Llama Stack")
-        collection = await self.client.get_collection(bank_id)
+        collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
         index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
diff --git a/llama_stack/providers/remote/memory/chroma/config.py b/llama_stack/providers/remote/memory/chroma/config.py
new file mode 100644
index 000000000..5a5e05dfb
--- /dev/null
+++ b/llama_stack/providers/remote/memory/chroma/config.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, model_validator
+
+
+@json_schema_type
+class ChromaConfig(BaseModel):
+    # You can either specify the url of the chroma server or the path to the local db
+    url: Optional[str] = None
+    db_path: Optional[str] = None
+
+    @model_validator(mode="after")
+    def check_url_or_db_path(self):
+        if not (self.url or self.db_path):
+            raise ValueError("Either url or db_path must be specified")
+        return self
+
+    @classmethod
+    def sample_config(cls) -> Dict[str, Any]:
+        return {"url": "{env.CHROMADB_URL}"}
diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py
index d77de7b41..9ec76e8ca 100644
--- a/llama_stack/providers/remote/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py
@@ -185,17 +185,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        banks = load_models(self.cursor, VectorMemoryBank)
-        for bank in banks:
-            if bank.identifier not in self.cache:
-                index = BankWithIndex(
-                    bank=bank,
-                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-                )
-                self.cache[bank.identifier] = index
-        return banks
-
     async def insert_documents(
         self,
         bank_id: str,
diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py
index be370eec9..a9badbd6a 100644
--- a/llama_stack/providers/remote/memory/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py
@@ -127,11 +127,6 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
         self.cache[memory_bank.identifier] = index
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        # Qdrant doesn't have collection level metadata to store the bank properties
-        # So we only return from the cache value
-        return [i.bank for i in self.cache.values()]
-
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]
diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py
index 3431b87d5..09ea2f32c 100644
--- a/llama_stack/providers/remote/memory/sample/sample.py
+++ b/llama_stack/providers/remote/memory/sample/sample.py
@@ -14,7 +14,7 @@ class SampleMemoryImpl(Memory):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         # these are the memory banks the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py
index f8fba5c0b..f05fc663e 100644
--- a/llama_stack/providers/remote/memory/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py
@@ -141,13 +141,6 @@ class WeaviateMemoryAdapter(
         )
         self.cache[memory_bank.identifier] = index
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        # TODO: right now the Llama Stack is the source of truth for these banks. That is
-        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
-        # list() happens at Stack startup when the Weaviate client (credentials) is not
-        # yet available. We need to figure out a way to make this work.
-        return [i.bank for i in self.cache.values()]
-
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]
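
Example (not part of the diff above): a minimal sketch of how the new ChromaConfig is intended to be used by the chroma provider. The URL and path values are illustrative placeholders, not values from the patch.

from llama_stack.providers.remote.memory.chroma.config import ChromaConfig

# Remote mode: only `url` is set; ChromaMemoryAdapter.initialize() then connects
# through chromadb.AsyncHttpClient.
remote_config = ChromaConfig(url="http://localhost:8000")

# Local/inline mode: only `db_path` is set; initialize() then opens a
# chromadb.PersistentClient at that path.
local_config = ChromaConfig(db_path="/tmp/chroma")

# Setting neither field is rejected by the model validator.
try:
    ChromaConfig()
except Exception as err:  # pydantic ValidationError wrapping the ValueError
    print(err)  # "Either url or db_path must be specified"

In a run config, the provider's config maps onto these same two fields; the sample_config in config.py points url at the CHROMADB_URL environment variable.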
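
Example (not part of the diff above): a self-contained illustration of the maybe_await pattern that chroma.py uses to drive Chroma's async HTTP client and the synchronous PersistentClient through the same call sites. The two stand-in functions below are hypothetical and exist only for this demonstration.

import asyncio


async def maybe_await(result):
    # Same helper as in chroma.py: await coroutines, pass plain values through.
    if asyncio.iscoroutine(result):
        return await result
    return result


def sync_get_count() -> int:
    # Stand-in for a synchronous PersistentClient method.
    return 1


async def async_get_count() -> int:
    # Stand-in for an AsyncHttpClient coroutine method.
    return 2


async def main() -> None:
    print(await maybe_await(sync_get_count()))   # 1 -- plain value passes through
    print(await maybe_await(async_get_count()))  # 2 -- coroutine gets awaited


asyncio.run(main())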