diff --git a/llama_stack/providers/inline/memory/chroma/__init__.py b/llama_stack/providers/inline/memory/chroma/__init__.py new file mode 100644 index 000000000..44279abd1 --- /dev/null +++ b/llama_stack/providers/inline/memory/chroma/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import ChromaInlineImplConfig + + +async def get_provider_impl(config: ChromaInlineImplConfig, _deps): + from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter + + impl = ChromaMemoryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/memory/chroma/config.py b/llama_stack/providers/inline/memory/chroma/config.py new file mode 100644 index 000000000..efbd77faf --- /dev/null +++ b/llama_stack/providers/inline/memory/chroma/config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + + +class ChromaInlineImplConfig(BaseModel): + db_path: str + + @classmethod + def sample_config(cls) -> Dict[str, Any]: + return {"db_path": "{env.CHROMADB_PATH}"} diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index 178cae574..c52aba6c6 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -53,9 +53,16 @@ def available_providers() -> List[ProviderSpec]: adapter_type="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], module="llama_stack.providers.remote.memory.chroma", - config_class="llama_stack.providers.remote.memory.chroma.ChromaConfig", + config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig", ), ), + InlineProviderSpec( + api=Api.memory, + provider_type="inline::chromadb", + pip_packages=EMBEDDING_DEPS + ["chromadb"], + module="llama_stack.providers.inline.memory.chroma", + config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig", + ), remote_provider_spec( Api.memory, AdapterSpec( diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index ee97d9705..63e9eae7d 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import ChromaConfig +from .config import ChromaRemoteImplConfig -async def get_adapter_impl(config: ChromaConfig, _deps): +async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps): from .chroma import ChromaMemoryAdapter impl = ChromaMemoryAdapter(config) diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 1b99c044d..3fe6aa938 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -7,7 +7,6 @@ import asyncio import json import logging from typing import List -from urllib.parse import urlparse import chromadb from numpy.typing import NDArray @@ -15,11 +14,12 @@ from numpy.typing import NDArray from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, ) -from .config import ChromaConfig +from .config import ChromaRemoteImplConfig log = logging.getLogger(__name__) @@ -85,31 +85,21 @@ class ChromaIndex(EmbeddingIndex): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: ChromaConfig) -> None: - log.info(f"Initializing ChromaMemoryAdapter with url: {config.url}") - url = url.rstrip("/") - parsed = urlparse(url) - - if parsed.path and parsed.path != "/": - raise ValueError("URL should not contain a path") - - self.host = parsed.hostname - self.port = parsed.port - + def __init__( + self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig] + ) -> None: + log.info(f"Initializing ChromaMemoryAdapter with url: {config}") + self.config = config self.client = None self.cache = {} async def initialize(self) -> None: - try: - if self.config.url: - log.info(f"Connecting to Chroma server at: {self.config.url}") - self.client = await chromadb.AsyncHttpClient(url=self.config.url) - else: - log.info(f"Connecting to Chroma local db at: {self.config.db_path}") - self.client = chromadb.PersistentClient(path=self.config.db_path) - except Exception as e: - log.exception("Could not connect to Chroma server") - raise RuntimeError("Could not connect to Chroma server") from e + if isinstance(self.config, ChromaRemoteImplConfig): + log.info(f"Connecting to Chroma server at: {self.config.url}") + self.client = await chromadb.AsyncHttpClient(url=self.config.url) + else: + log.info(f"Connecting to Chroma local db at: {self.config.db_path}") + self.client = chromadb.PersistentClient(path=self.config.db_path) async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/remote/memory/chroma/config.py b/llama_stack/providers/remote/memory/chroma/config.py index 5a5e05dfb..68ca2c967 100644 --- a/llama_stack/providers/remote/memory/chroma/config.py +++ b/llama_stack/providers/remote/memory/chroma/config.py @@ -4,22 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any, Dict -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, model_validator +from pydantic import BaseModel -@json_schema_type -class ChromaConfig(BaseModel): - # You can either specify the url of the chroma server or the path to the local db - url: Optional[str] = None - db_path: Optional[str] = None - - @model_validator(mode="after") - def check_url_or_db_path(self): - if not (self.url or self.db_path): - raise ValueError("Either url or db_path must be specified") +class ChromaRemoteImplConfig(BaseModel): + url: str @classmethod def sample_config(cls) -> Dict[str, Any]: diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..cc57bb916 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,8 +10,10 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig +from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig +from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test @@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture: @pytest.fixture(scope="session") def memory_chroma() -> ProviderFixture: + url = os.getenv("CHROMA_URL") + if url: + config = ChromaRemoteImplConfig(url=url) + provider_type = "remote::chromadb" + else: + if not os.getenv("CHROMA_DB_PATH"): + raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set") + config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH")) + provider_type = "inline::chromadb" return ProviderFixture( providers=[ Provider( provider_id="chroma", - provider_type="remote::chromadb", - config=RemoteProviderConfig( - host=get_env_or_fail("CHROMA_HOST"), - port=get_env_or_fail("CHROMA_PORT"), - ).model_dump(), + provider_type=provider_type, + config=config.model_dump(), ) ] )