Create a separate inline::chromadb provider

This commit is contained in:
Ashwin Bharambe 2024-12-11 14:11:08 -08:00
parent 44ab7d93fb
commit f0e045d1c8
7 changed files with 73 additions and 45 deletions

View file

@ -10,8 +10,10 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture:
@pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture(
providers=[
Provider(
provider_id="chroma",
provider_type="remote::chromadb",
config=RemoteProviderConfig(
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
provider_type=provider_type,
config=config.model_dump(),
)
]
)