From 6cd339a2f25f8f7c007dea88e2f4595ac8ce6f91 Mon Sep 17 00:00:00 2001
From: sarthakdeshpande
Date: Sun, 22 Jun 2025 22:05:20 +0530
Subject: [PATCH] chore: add OpenAI-compatible vector IO endpoints for
 chromadb

---
 .../inline/vector_io/chroma/__init__.py       |   2 +-
 .../remote/vector_io/chroma/__init__.py       |   2 +-
 .../remote/vector_io/chroma/chroma.py         | 202 +++++++++++-------
 .../vector_io/test_openai_vector_stores.py    |   4 +-
 4 files changed, 123 insertions(+), 87 deletions(-)

diff --git a/llama_stack/providers/inline/vector_io/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py
index 2e0efb8a1..988c4b4b6 100644
--- a/llama_stack/providers/inline/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/inline/vector_io/chroma/__init__.py
@@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
         ChromaVectorIOAdapter,
     )
 
-    impl = ChromaVectorIOAdapter(config, deps[Api.inference])
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/chroma/__init__.py b/llama_stack/providers/remote/vector_io/chroma/__init__.py
index ebbc62b1c..e4b77c68d 100644
--- a/llama_stack/providers/remote/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/remote/vector_io/chroma/__init__.py
@@ -12,6 +12,6 @@ from .config import ChromaVectorIOConfig
 async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter
 
-    impl = ChromaVectorIOAdapter(config, deps[Api.inference])
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index ffe2cba44..1f059acbf 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -6,12 +6,15 @@
 import asyncio
 import json
 import logging
+import uuid
 from typing import Any
 from urllib.parse import urlparse
 
 import chromadb
+from chromadb.errors import NotFoundError
 from numpy.typing import NDArray
 
+from llama_stack.apis.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
@@ -19,6 +22,13 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     SearchRankingOptions,
     VectorIO,
+    VectorStoreDeleteResponse,
+    VectorStoreListResponse,
+    VectorStoreObject,
+    VectorStoreSearchResponsePage,
+    VectorStoreFileDeleteResponse,
+)
+from llama_stack.apis.vector_io.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreDeleteResponse,
     VectorStoreFileContentsResponse,
@@ -31,6 +41,7 @@ from llama_stack.apis.vector_io import (
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -122,11 +133,12 @@ class ChromaIndex(EmbeddingIndex):
         raise NotImplementedError("Hybrid search is not supported in Chroma")
 
 
-class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
         self,
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Api.inference,
+        files_api: Files | None,
     ) -> None:
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
@@ -137,9 +149,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
 
     async def initialize(self) -> None:
         if isinstance(self.config, RemoteChromaVectorIOConfig):
-            if not self.config.url:
-                raise ValueError("URL is a required parameter for the remote Chroma provider's config")
-
             log.info(f"Connecting to Chroma server at: {self.config.url}")
             url = self.config.url.rstrip("/")
             parsed = urlparse(url)
@@ -151,6 +160,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         else:
             log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
             self.client = chromadb.PersistentClient(path=self.config.db_path)
+        self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
         pass
@@ -207,70 +217,107 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self.cache[vector_db_id] = index
         return index
 
-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        try:
+            try:
+                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+            except NotFoundError:
+                collection = await maybe_await(
+                    self.client.create_collection(
+                        name=self.metadata_collection_name,
+                        metadata={"description": "Collection to store metadata for OpenAI vector stores"},
+                    )
+                )
 
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+            await maybe_await(
+                collection.add(
+                    ids=[store_id],
+                    # Chroma rejects entries that carry only metadata, so persist the
+                    # serialized store record as the document as well.
+                    documents=[json.dumps(store_info)],
+                    metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
+                )
+            )
 
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+            self.openai_vector_stores[store_id] = store_info
 
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+        except Exception as e:
+            log.error(f"Error saving openai vector store {store_id}: {e}")
+            raise
 
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        openai_vector_stores = {}
+        try:
+            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+        except NotFoundError:
+            return openai_vector_stores
 
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+        try:
+            collection_count = await maybe_await(collection.count())
+            if collection_count == 0:
+                return openai_vector_stores
+
+            offset = 0
+            batch_size = 100
+            while True:
+                result = await maybe_await(
+                    collection.get(
+                        offset=offset,
+                        limit=batch_size,
+                        include=["documents", "metadatas"],
+                    )
+                )
+                if not result["ids"]:
+                    break
+                for i, doc_id in enumerate(result["ids"]):
+                    metadata = result.get("metadatas", [{}])[i] if i < len(result.get("metadatas", [])) else {}
+
+                    # Extract store_id (assuming it's in metadata)
+                    store_id = metadata.get("store_id")
+
+                    if store_id:
+                        # If metadata contains a JSON string, parse it
+                        metadata_json = metadata.get("metadata")
+                        if metadata_json:
+                            try:
+                                if isinstance(metadata_json, str):
+                                    store_info = json.loads(metadata_json)
+                                else:
+                                    store_info = metadata_json
+                                openai_vector_stores[store_id] = store_info
+                            except json.JSONDecodeError:
+                                log.error(f"failed to decode metadata for store_id {store_id}")
+                offset += batch_size
+        except Exception as e:
+            log.error(f"error loading openai vector stores: {e}")
+        return openai_vector_stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        try:
+            if store_id in self.openai_vector_stores:
+                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+                await maybe_await(
+                    collection.update(
+                        ids=[store_id],
+                        metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
+                    )
+                )
+                self.openai_vector_stores[store_id] = store_info
+        except NotFoundError:
+            log.error(f"Collection {self.metadata_collection_name} not found")
+        except Exception as e:
+            log.error(f"Error updating openai vector store {store_id}: {e}")
+            raise
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        try:
+            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+            await maybe_await(collection.delete(ids=[store_id]))
+        except NotFoundError:
+            log.error(f"Collection {self.metadata_collection_name} not found")
+        except Exception as e:
+            log.error(f"Error deleting openai vector store {store_id}: {e}")
+            raise
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        """Delete vector store file metadata from persistent storage."""
 
     async def openai_list_files_in_vector_store(
         self,
         vector_store_id: str,
@@ -282,31 +329,20 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     ) -> VectorStoreListFilesResponse:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
 
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        """Load vector store file metadata from persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
 
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        """Load vector store file contents from persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
 
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        """Save vector store file metadata to persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
 
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        """Update vector store file metadata in persistent storage."""
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
\ No newline at end of file
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index d7300348b..ec89c0bd1 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus"]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
             return
     pytest.skip("OpenAI vector stores are not supported by any provider")
 
@@ -31,7 +31,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
 def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector"]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector", "inline::chromadb"]:
             return
     pytest.skip("OpenAI vector stores are not supported by any provider")
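
--
Editor's note (not part of the patch): the persistence scheme above keeps one entry per
vector store in a dedicated Chroma collection, with the store record serialized to JSON in
the entry's metadata. The sketch below is a standalone illustration of that round-trip
against a local Chroma database. The collection name, path, and store payload are
placeholders, and get_or_create_collection stands in for the get/create fallback in
_save_openai_vector_store. Because Chroma's add() rejects entries that carry only
metadata, the serialized JSON doubles as the document, matching the collection.add call
in the patch.

    import json

    import chromadb

    client = chromadb.PersistentClient(path="/tmp/chroma-metadata-demo")  # placeholder path
    collection = client.get_or_create_collection(
        name="openai_vector_stores_metadata",  # stand-in for self.metadata_collection_name
        metadata={"description": "Collection to store metadata for OpenAI vector stores"},
    )

    store_id = "vs_demo_123"  # hypothetical store id
    store_info = {"id": store_id, "name": "demo store", "file_ids": []}

    # Save: one entry per store, with the JSON payload in both document and metadata.
    collection.add(
        ids=[store_id],
        documents=[json.dumps(store_info)],
        metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
    )

    # Load: page through the collection and rebuild the in-memory dict, mirroring
    # _load_openai_vector_stores.
    stores = {}
    offset, batch_size = 0, 100
    while True:
        result = collection.get(offset=offset, limit=batch_size, include=["metadatas"])
        if not result["ids"]:
            break
        for meta in result["metadatas"]:
            stores[meta["store_id"]] = json.loads(meta["metadata"])
        offset += batch_size

    assert stores[store_id]["name"] == "demo store"

Keeping this bookkeeping in its own collection, rather than alongside chunk embeddings,
lets store metadata survive restarts without interfering with similarity search; and since
every client call in the adapter goes through the existing maybe_await helper, the same
code paths serve both the synchronous PersistentClient and Chroma's async HTTP client.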