updated tests and adpaters to include chroma

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-07-23 15:14:31 -04:00
parent 6cd339a2f2
commit 0c24d0cc41
3 changed files with 56 additions and 158 deletions

View file

@ -6,12 +6,10 @@
import asyncio import asyncio
import json import json
import logging import logging
import uuid
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import chromadb import chromadb
from chromadb.errors import NotFoundError
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
@ -20,24 +18,7 @@ from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
Chunk, Chunk,
QueryChunksResponse, QueryChunksResponse,
SearchRankingOptions,
VectorIO, VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
VectorStoreFileDeleteResponse,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@ -138,7 +119,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self, self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig, config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Api.inference, inference_api: Api.inference,
files_api: Files | None files_api: Files | None,
) -> None: ) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config self.config = config
@ -216,133 +197,3 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api) index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index self.cache[vector_db_id] = index
return index return index
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
try:
collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
except NotFoundError:
collection = await maybe_await(
self.client.create_collection(name=self.metadata_collection_name, metadata={
"description": "Collection to store metadata for OpenAI vector stores"
})
)
await maybe_await(
collection.add(
ids=[store_id],
metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
)
)
self.openai_vector_stores[store_id] = store_info
except Exception as e:
log.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
openai_vector_stores = {}
try:
collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
except NotFoundError:
return openai_vector_stores
try:
collection_count = await maybe_await(collection.count())
if collection_count == 0:
return openai_vector_stores
offset = 0
batch_size = 100
while True:
result = await maybe_await(
collection.get(
where={"store_id": {"$exists": True}},
offset=offset,
limit=batch_size,
include=["documents", "metadatas"],
)
)
if not result['ids'] or len(result['ids']) == 0:
break
for i, doc_id in enumerate(result['ids']):
metadata = result.get('metadatas', [{}])[i] if i < len(result.get('metadatas', [])) else {}
# Extract store_id (assuming it's in metadata)
store_id = metadata.get('store_id')
if store_id:
# If metadata contains JSON string, parse it
metadata_json = metadata.get('metadata')
if metadata_json:
try:
if isinstance(metadata_json, str):
store_info = json.loads(metadata_json)
else:
store_info = metadata_json
openai_vector_stores[store_id] = store_info
except json.JSONDecodeError:
log.error(f"failed to decode metadata for store_id {store_id}")
offset += batch_size
except Exception as e:
log.error(f"error loading openai vector stores: {e}")
return openai_vector_stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
try:
if store_id in self.openai_vector_stores:
collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
await maybe_await(
collection.update(
ids=[store_id],
metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
)
)
self.openai_vector_stores[store_id] = store_info
except NotFoundError:
log.error(f"Collection {self.metadata_collection_name} not found")
except Exception as e:
log.error(f"Error updating openai vector store {store_id}: {e}")
raise
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
try:
collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
await maybe_await(collection.delete(ids=[store_id]))
except ValueError:
log.error(f"Collection {self.metadata_collection_name} not found")
except Exception as e:
log.error(f"Error deleting openai vector store {store_id}: {e}")
raise
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from persistent storage."""
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from persistent storage."""
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from persistent storage."""
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to persistent storage."""
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in persistent storage."""
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb]: if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
return return
pytest.skip("OpenAI vector stores are not supported by any provider") pytest.skip("OpenAI vector stores are not supported by any provider")
@ -31,7 +31,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models): def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector", "inline::chromadb"]: if p.provider_type in [
"inline::faiss",
"inline::sqlite-vec",
"inline::milvus",
"remote::pgvector",
"inline::chromadb",
]:
return return
pytest.skip("OpenAI vector stores are not supported by any provider") pytest.skip("OpenAI vector stores are not supported by any provider")

View file

@ -12,11 +12,13 @@ from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
EMBEDDING_DIMENSION = 384 EMBEDDING_DIMENSION = 384
@ -236,15 +238,54 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
await adapter.shutdown() await adapter.shutdown()
@pytest.fixture
def chroma_vec_db_path(tmp_path_factory):
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
return str(persist_dir)
@pytest.fixture
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
index = ChromaIndex(
embedding_dimension=embedding_dimension,
persist_directory=chroma_vec_db_path,
)
await index.initialize()
yield index
await index.delete()
@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path)
adapter = ChromaVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture @pytest.fixture
def vector_io_adapter(vector_provider, request): def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter.""" """Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus": vector_provider_dict = {
return request.getfixturevalue("milvus_vec_adapter") "milvus": "milvus_vec_adapter",
elif vector_provider == "faiss": "faiss": "faiss_vec_adapter",
return request.getfixturevalue("faiss_vec_adapter") "sqlite_vec": "sqlite_vec_adapter",
else: "chroma": "chroma_vec_adapter",
return request.getfixturevalue("sqlite_vec_adapter") }
return request.getfixturevalue(vector_provider_dict[vector_provider])
@pytest.fixture @pytest.fixture