diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 91f400fe9..fbb3d8b8d 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -16,12 +16,13 @@ from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as RemoteChromaVectorIOConfig from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) -from .config import ChromaVectorIOConfig +from .config import ChromaVectorIOConfig as InlineChromaVectorIOConfig log = logging.getLogger(__name__) @@ -88,7 +89,7 @@ class ChromaIndex(EmbeddingIndex): class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( self, - config: ChromaVectorIOConfig, + config: Union[RemoteChromaVectorIOConfig, InlineChromaVectorIOConfig], inference_api: Api.inference, ) -> None: log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") @@ -99,7 +100,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache = {} async def initialize(self) -> None: - if isinstance(self.config, ChromaVectorIOConfig): + if isinstance(self.config, RemoteChromaVectorIOConfig): log.info(f"Connecting to Chroma server at: {self.config.url}") url = self.config.url.rstrip("/") parsed = urlparse(url)