diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 1f059acbf..142e7f701 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -6,12 +6,10 @@
 import asyncio
 import json
 import logging
-import uuid
 from typing import Any
 from urllib.parse import urlparse
 
 import chromadb
-from chromadb.errors import NotFoundError
 from numpy.typing import NDArray
 
 from llama_stack.apis.files import Files
@@ -20,24 +18,7 @@ from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
-    SearchRankingOptions,
     VectorIO,
-    VectorStoreDeleteResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
-    VectorStoreFileDeleteResponse,
-)
-from llama_stack.apis.vector_io.vector_io import (
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -138,7 +119,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Api.inference,
-        files_api: Files | None
+        files_api: Files | None,
     ) -> None:
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
@@ -216,133 +197,3 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_db_id] = index
         return index
-
-
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        try:
-            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-        except NotFoundError:
-            collection = await maybe_await(
-                self.client.create_collection(name=self.metadata_collection_name, metadata={
-                    "description": "Collection to store metadata for OpenAI vector stores"
-                })
-            )
-
-            await maybe_await(
-                collection.add(
-                    ids=[store_id],
-                    metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
-                )
-            )
-
-            self.openai_vector_stores[store_id] = store_info
-
-        except Exception as e:
-            log.error(f"Error saving openai vector store {store_id}: {e}")
-            raise
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        openai_vector_stores = {}
-        try:
-            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-        except NotFoundError:
-            return openai_vector_stores
-
-        try:
-            collection_count = await maybe_await(collection.count())
-            if collection_count == 0:
-                return openai_vector_stores
-            offset = 0
-            batch_size = 100
-            while True:
-                result = await maybe_await(
-                    collection.get(
-                        where={"store_id": {"$exists": True}},
-                        offset=offset,
-                        limit=batch_size,
-                        include=["documents", "metadatas"],
-                    )
-                )
-                if not result['ids'] or len(result['ids']) == 0:
-                    break
-
-                for i, doc_id in enumerate(result['ids']):
-                    metadata = result.get('metadatas', [{}])[i] if i < len(result.get('metadatas', [])) else {}
-
-                    # Extract store_id (assuming it's in metadata)
-                    store_id = metadata.get('store_id')
-
-                    if store_id:
-                        # If metadata contains JSON string, parse it
-                        metadata_json = metadata.get('metadata')
-                        if metadata_json:
-                            try:
-                                if isinstance(metadata_json, str):
-                                    store_info = json.loads(metadata_json)
-                                else:
-                                    store_info = metadata_json
-                                openai_vector_stores[store_id] = store_info
-                            except json.JSONDecodeError:
-                                log.error(f"failed to decode metadata for store_id {store_id}")
-                offset += batch_size
-        except Exception as e:
-            log.error(f"error loading openai vector stores: {e}")
-        return openai_vector_stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        try:
-            if store_id in self.openai_vector_stores:
-                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-                await maybe_await(
-                    collection.update(
-                        ids=[store_id],
-                        metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
-                    )
-                )
-                self.openai_vector_stores[store_id] = store_info
-        except NotFoundError:
-            log.error(f"Collection {self.metadata_collection_name} not found")
-        except Exception as e:
-            log.error(f"Error updating openai vector store {store_id}: {e}")
-            raise
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        try:
-            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-            await maybe_await(collection.delete(ids=[store_id]))
-        except ValueError:
-            log.error(f"Collection {self.metadata_collection_name} not found")
-        except Exception as e:
-            log.error(f"Error deleting openai vector store {store_id}: {e}")
-            raise
-
-    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-        """Delete vector store file metadata from persistent storage."""
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
-        """Load vector store file metadata from persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
-        """Load vector store file contents from persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _save_openai_vector_store_file(
-        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
-    ) -> None:
-        """Save vector store file metadata to persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
-        """Update vector store file metadata in persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
\ No newline at end of file
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index ec89c0bd1..5dcd321c3 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
             return
 
     pytest.skip("OpenAI vector stores are not supported by any provider")
@@ -31,7 +31,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
 def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector", "inline::chromadb"]:
+        if p.provider_type in [
+            "inline::faiss",
+            "inline::sqlite-vec",
+            "inline::milvus",
+            "remote::pgvector",
+            "inline::chromadb",
+        ]:
             return
 
     pytest.skip("OpenAI vector stores are not supported by any provider")
diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py
index 9f86f877d..45e37d6ff 100644
--- a/tests/unit/providers/vector_io/conftest.py
+++ b/tests/unit/providers/vector_io/conftest.py
@@ -12,11 +12,13 @@ from pymilvus import MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
 from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
@@ -236,15 +238,54 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
     await adapter.shutdown()
 
 
+@pytest.fixture
+def chroma_vec_db_path(tmp_path_factory):
+    persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
+    return str(persist_dir)
+
+
+@pytest.fixture
+async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
+    index = ChromaIndex(
+        embedding_dimension=embedding_dimension,
+        persist_directory=chroma_vec_db_path,
+    )
+    await index.initialize()
+    yield index
+    await index.delete()
+
+
+@pytest.fixture
+async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+    config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path)
+    adapter = ChromaVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    yield adapter
+    await adapter.shutdown()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     """Returns the appropriate vector IO adapter based on the provider parameter."""
-    if vector_provider == "milvus":
-        return request.getfixturevalue("milvus_vec_adapter")
-    elif vector_provider == "faiss":
-        return request.getfixturevalue("faiss_vec_adapter")
-    else:
-        return request.getfixturevalue("sqlite_vec_adapter")
+    vector_provider_dict = {
+        "milvus": "milvus_vec_adapter",
+        "faiss": "faiss_vec_adapter",
+        "sqlite_vec": "sqlite_vec_adapter",
+        "chroma": "chroma_vec_adapter",
+    }
+    return request.getfixturevalue(vector_provider_dict[vector_provider])
 
 
 @pytest.fixture