commit 76afaf302f
Sarthak Deshpande 2025-06-24 15:07:08 +02:00 committed by GitHub
4 changed files with 127 additions and 95 deletions

View file

@@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
         ChromaVectorIOAdapter,
     )

-    impl = ChromaVectorIOAdapter(config, deps[Api.inference])
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -12,6 +12,6 @@ from .config import ChromaVectorIOConfig
 async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter

-    impl = ChromaVectorIOAdapter(config, deps[Api.inference])
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

View file

@@ -6,12 +6,15 @@
 import asyncio
 import json
 import logging
+import uuid
 from typing import Any
 from urllib.parse import urlparse

 import chromadb
+from chromadb.errors import NotFoundError
 from numpy.typing import NDArray

+from llama_stack.apis.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
@@ -23,6 +26,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreListResponse,
     VectorStoreObject,
     VectorStoreSearchResponsePage,
+    VectorStoreFileDeleteResponse,
 )
 from llama_stack.apis.vector_io.vector_io import (
     VectorStoreChunkingStrategy,
@@ -32,6 +36,7 @@ from llama_stack.apis.vector_io.vector_io import (
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -123,16 +128,20 @@ class ChromaIndex(EmbeddingIndex):
         raise NotImplementedError("Hybrid search is not supported in Chroma")


-class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
         self,
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Api.inference,
+        files_api: Files | None,
     ) -> None:
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.files_api = files_api
+        self.metadata_collection_name = "openai_vector_stores_metadata"
         self.client = None
         self.cache = {}
@@ -149,6 +158,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         else:
             log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
             self.client = chromadb.PersistentClient(path=self.config.db_path)
+        self.openai_vector_stores = await self._load_openai_vector_stores()

     async def shutdown(self) -> None:
         pass
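The metadata helpers in the next hunk wrap every Chroma call in maybe_await, a helper this module already uses (not shown in this diff) that lets one code path serve both the synchronous PersistentClient and the async HTTP client. A sketch, assuming the in-tree helper follows the usual shape:

    import asyncio
    from typing import Any

    async def maybe_await(result: Any) -> Any:
        # PersistentClient methods return plain values; the async client
        # returns coroutines, so await only when needed.
        if asyncio.iscoroutine(result):
            return await result
        return result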
@@ -205,101 +215,123 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self.cache[vector_db_id] = index
         return index

-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        try:
+            # get_collection raises NotFoundError before first use; create the
+            # metadata collection lazily in that case
+            try:
+                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+            except NotFoundError:
+                collection = await maybe_await(
+                    self.client.create_collection(
+                        name=self.metadata_collection_name,
+                        metadata={"description": "Collection to store metadata for OpenAI vector stores"},
+                    )
+                )
+            await maybe_await(
+                collection.add(
+                    ids=[store_id],
+                    metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
+                )
+            )
+            self.openai_vector_stores[store_id] = store_info
+        except Exception as e:
+            log.error(f"Error saving openai vector store {store_id}: {e}")
+            raise
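For clarity, the record written above has this shape in the metadata collection; Chroma metadata values must be primitives, which is why the store info is JSON-encoded into a single string that the loader parses back with json.loads (ids and payload below are hypothetical):

    import json

    store_id = "vs_123"  # hypothetical store id
    store_info = {"id": store_id, "name": "docs"}  # hypothetical payload

    # One Chroma record per vector store:
    ids = [store_id]
    metadatas = [{"store_id": store_id, "metadata": json.dumps(store_info)}]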
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        openai_vector_stores = {}
+        try:
+            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+        except NotFoundError:
+            # No metadata collection yet means no stores have been created
+            return openai_vector_stores
+
+        try:
+            collection_count = await maybe_await(collection.count())
+            if collection_count == 0:
+                return openai_vector_stores
+
+            offset = 0
+            batch_size = 100
+            while True:
+                result = await maybe_await(
+                    collection.get(
+                        where={"store_id": {"$exists": True}},
+                        offset=offset,
+                        limit=batch_size,
+                        include=["documents", "metadatas"],
+                    )
+                )
+                if not result["ids"]:
+                    break
+                for i, doc_id in enumerate(result["ids"]):
+                    metadatas = result.get("metadatas") or []
+                    metadata = metadatas[i] if i < len(metadatas) else {}
+                    store_id = metadata.get("store_id")
+                    if store_id:
+                        # The store info is persisted as a JSON-encoded string
+                        metadata_json = metadata.get("metadata")
+                        if metadata_json:
+                            try:
+                                store_info = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
+                                openai_vector_stores[store_id] = store_info
+                            except json.JSONDecodeError:
+                                log.error(f"Failed to decode metadata for store_id {store_id}")
+                offset += batch_size
+        except Exception as e:
+            log.error(f"Error loading openai vector stores: {e}")
+        return openai_vector_stores
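The loader pages through the metadata collection in offset/limit batches of 100 and stops at the first empty page, so startup cost stays bounded as stores accumulate. The same loop shape against a plain list (data and page size hypothetical):

    data = [f"store_{i}" for i in range(250)]  # hypothetical rows
    batch_size = 100
    offset = 0
    pages = []
    while True:
        page = data[offset : offset + batch_size]
        if not page:  # an empty page ends the loop, like an empty result["ids"]
            break
        pages.append(page)
        offset += batch_size
    assert [len(p) for p in pages] == [100, 100, 50]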
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        try:
+            if store_id in self.openai_vector_stores:
+                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+                await maybe_await(
+                    collection.update(
+                        ids=[store_id],
+                        metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
+                    )
+                )
+                self.openai_vector_stores[store_id] = store_info
+        except NotFoundError:
+            log.error(f"Collection {self.metadata_collection_name} not found")
+        except Exception as e:
+            log.error(f"Error updating openai vector store {store_id}: {e}")
+            raise
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        try:
+            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
+            await maybe_await(collection.delete(ids=[store_id]))
+        except ValueError:
+            log.error(f"Collection {self.metadata_collection_name} not found")
+        except Exception as e:
+            log.error(f"Error deleting openai vector store {store_id}: {e}")
+            raise
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        """Delete vector store file metadata from persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        """Load vector store file metadata from persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        """Load vector store file contents from persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        """Save vector store file metadata to persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        """Update vector store file metadata in persistent storage."""
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
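With this change the store-level OpenAI operations (create, list, retrieve, update, delete, search) are inherited from OpenAIVectorStoreMixin and persist through the Chroma metadata collection, while the per-file hooks above still raise NotImplementedError, so mixin flows that attach or read files will fail on this provider. A runnable sketch of what a caller should expect (the stub class below is a hypothetical stand-in for the adapter):

    import asyncio

    class _StubAdapter:
        # Hypothetical stand-in mirroring the unimplemented hook above
        async def _save_openai_vector_store_file(self, store_id, file_id, file_info, file_contents):
            raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def demo() -> None:
        try:
            await _StubAdapter()._save_openai_vector_store_file("vs_123", "file_1", {}, [])
        except NotImplementedError as exc:
            print(exc)  # OpenAI Vector Stores API is not supported in Chroma

    asyncio.run(demo())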

View file

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::chromadb"]:
             return
     pytest.skip("OpenAI vector stores are not supported by any provider")
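With inline::chromadb added to the allow-list, the OpenAI vector-store integration tests now run against the inline Chroma provider instead of being skipped. A hypothetical test using the helper (the fixture name comes from the file; the vector_stores.create call is assumed to follow the suite's existing client API):

    def test_openai_create_vector_store(client_with_models):
        skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
        store = client_with_models.vector_stores.create(name="my_store")
        assert store.name == "my_store"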