Create a separate inline::chromadb provider

This commit is contained in:
Ashwin Bharambe 2024-12-11 14:11:08 -08:00
parent 44ab7d93fb
commit f0e045d1c8
7 changed files with 73 additions and 45 deletions

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ChromaInlineImplConfig
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
    """Create and initialize the implementation for the inline Chroma provider.

    Args:
        config: Inline Chroma configuration, passed straight to the adapter.
        _deps: Provider dependencies; not used by this function.

    Returns:
        A ``ChromaMemoryAdapter`` whose ``initialize()`` has been awaited.
    """
    # Kept function-local on purpose. NOTE(review): the remote chroma module
    # appears to import this package's config, so a module-level import here
    # would presumably be circular -- confirm before hoisting.
    from llama_stack.providers.remote.memory.chroma.chroma import (
        ChromaMemoryAdapter as _Adapter,
    )

    adapter = _Adapter(config)
    await adapter.initialize()
    return adapter

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ChromaInlineImplConfig(BaseModel):
    # Configuration model for the inline (in-process) Chroma memory provider.

    # Required string holding the location of the local Chroma database.
    db_path: str

    @classmethod
    def sample_config(cls) -> Dict[str, Any]:
        """Return a sample configuration mapping for this provider.

        The ``db_path`` value is the literal placeholder string
        ``{env.CHROMADB_PATH}``; nothing is substituted here.
        """
        sample: Dict[str, Any] = {}
        sample["db_path"] = "{env.CHROMADB_PATH}"
        return sample

View file

@ -53,9 +53,16 @@ def available_providers() -> List[ProviderSpec]:
adapter_type="chromadb", adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"], pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma", module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.providers.remote.memory.chroma.ChromaConfig", config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
), ),
), ),
InlineProviderSpec(
api=Api.memory,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
module="llama_stack.providers.inline.memory.chroma",
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
AdapterSpec( AdapterSpec(

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .config import ChromaConfig from .config import ChromaRemoteImplConfig
async def get_adapter_impl(config: ChromaConfig, _deps): async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
from .chroma import ChromaMemoryAdapter from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config) impl = ChromaMemoryAdapter(config)

View file

@ -7,7 +7,6 @@ import asyncio
import json import json
import logging import logging
from typing import List from typing import List
from urllib.parse import urlparse
import chromadb import chromadb
from numpy.typing import NDArray from numpy.typing import NDArray
@ -15,11 +14,12 @@ from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
) )
from .config import ChromaConfig from .config import ChromaRemoteImplConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -85,31 +85,21 @@ class ChromaIndex(EmbeddingIndex):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: ChromaConfig) -> None: def __init__(
log.info(f"Initializing ChromaMemoryAdapter with url: {config.url}") self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
url = url.rstrip("/") ) -> None:
parsed = urlparse(url) log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
self.config = config
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.host = parsed.hostname
self.port = parsed.port
self.client = None self.client = None
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize(self) -> None:
try: if isinstance(self.config, ChromaRemoteImplConfig):
if self.config.url: log.info(f"Connecting to Chroma server at: {self.config.url}")
log.info(f"Connecting to Chroma server at: {self.config.url}") self.client = await chromadb.AsyncHttpClient(url=self.config.url)
self.client = await chromadb.AsyncHttpClient(url=self.config.url) else:
else: log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
log.info(f"Connecting to Chroma local db at: {self.config.db_path}") self.client = chromadb.PersistentClient(path=self.config.db_path)
self.client = chromadb.PersistentClient(path=self.config.db_path)
except Exception as e:
log.exception("Could not connect to Chroma server")
raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass

View file

@ -4,22 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, Optional from typing import Any, Dict
from llama_models.schema_utils import json_schema_type from pydantic import BaseModel
from pydantic import BaseModel, model_validator
@json_schema_type class ChromaRemoteImplConfig(BaseModel):
class ChromaConfig(BaseModel): url: str
# You can either specify the url of the chroma server or the path to the local db
url: Optional[str] = None
db_path: Optional[str] = None
@model_validator(mode="after")
def check_url_or_db_path(self):
if not (self.url or self.db_path):
raise ValueError("Either url or db_path must be specified")
@classmethod @classmethod
def sample_config(cls) -> Dict[str, Any]: def sample_config(cls) -> Dict[str, Any]:

View file

@ -10,8 +10,10 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture: def memory_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture( return ProviderFixture(
providers=[ providers=[
Provider( Provider(
provider_id="chroma", provider_id="chroma",
provider_type="remote::chromadb", provider_type=provider_type,
config=RemoteProviderConfig( config=config.model_dump(),
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
) )
] ]
) )