Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 22:42:25 +00:00)

commit 0c24d0cc41
parent 6cd339a2f2

    updated tests and adapters to include chroma

    Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

3 changed files with 56 additions and 158 deletions
@@ -6,12 +6,10 @@
 import asyncio
 import json
 import logging
-import uuid
 from typing import Any
 from urllib.parse import urlparse

 import chromadb
-from chromadb.errors import NotFoundError
 from numpy.typing import NDArray

 from llama_stack.apis.files import Files
@@ -20,24 +18,7 @@ from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
-    SearchRankingOptions,
     VectorIO,
-    VectorStoreDeleteResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
-    VectorStoreFileDeleteResponse,
-)
-from llama_stack.apis.vector_io.vector_io import (
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -138,7 +119,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self,
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Api.inference,
-        files_api: Files | None
+        files_api: Files | None,
     ) -> None:
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
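Both config types end up on the adapter, so at runtime `self.client` may be either a synchronous `chromadb` client (inline) or an async HTTP client (remote). The methods deleted in the next hunk bridge that difference by wrapping every client call in `maybe_await`. A minimal sketch of what such a helper presumably does; this body is an assumption, the real helper in the repo may differ:

import asyncio
from typing import Any


async def maybe_await(result: Any) -> Any:
    # Async Chroma client methods return coroutines; the sync client
    # returns plain values. Await only when there is something to await.
    if asyncio.iscoroutine(result):
        return await result
    return result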
@@ -216,133 +197,3 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_db_id] = index
         return index
-
-
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        try:
-            try:
-                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-            except NotFoundError:
-                collection = await maybe_await(
-                    self.client.create_collection(name=self.metadata_collection_name, metadata={
-                        "description": "Collection to store metadata for OpenAI vector stores"
-                    })
-                )
-
-            await maybe_await(
-                collection.add(
-                    ids=[store_id],
-                    metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
-                )
-            )
-
-            self.openai_vector_stores[store_id] = store_info
-
-        except Exception as e:
-            log.error(f"Error saving openai vector store {store_id}: {e}")
-            raise
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        openai_vector_stores = {}
-        try:
-            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-        except NotFoundError:
-            return openai_vector_stores
-
-        try:
-            collection_count = await maybe_await(collection.count())
-            if collection_count == 0:
-                return openai_vector_stores
-
-            offset = 0
-            batch_size = 100
-            while True:
-                result = await maybe_await(
-                    collection.get(
-                        where={"store_id": {"$exists": True}},
-                        offset=offset,
-                        limit=batch_size,
-                        include=["documents", "metadatas"],
-                    )
-                )
-                if not result['ids'] or len(result['ids']) == 0:
-                    break
-
-                for i, doc_id in enumerate(result['ids']):
-                    metadata = result.get('metadatas', [{}])[i] if i < len(result.get('metadatas', [])) else {}
-
-                    # Extract store_id (assuming it's in metadata)
-                    store_id = metadata.get('store_id')
-
-                    if store_id:
-                        # If metadata contains JSON string, parse it
-                        metadata_json = metadata.get('metadata')
-                        if metadata_json:
-                            try:
-                                if isinstance(metadata_json, str):
-                                    store_info = json.loads(metadata_json)
-                                else:
-                                    store_info = metadata_json
-                                openai_vector_stores[store_id] = store_info
-                            except json.JSONDecodeError:
-                                log.error(f"failed to decode metadata for store_id {store_id}")
-                offset += batch_size
-        except Exception as e:
-            log.error(f"error loading openai vector stores: {e}")
-        return openai_vector_stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        try:
-            if store_id in self.openai_vector_stores:
-                collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-                await maybe_await(
-                    collection.update(
-                        ids=[store_id],
-                        metadatas=[{"store_id": store_id, "metadata": json.dumps(store_info)}],
-                    )
-                )
-                self.openai_vector_stores[store_id] = store_info
-        except NotFoundError:
-            log.error(f"Collection {self.metadata_collection_name} not found")
-        except Exception as e:
-            log.error(f"Error updating openai vector store {store_id}: {e}")
-            raise
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        try:
-            collection = await maybe_await(self.client.get_collection(name=self.metadata_collection_name))
-            await maybe_await(collection.delete(ids=[store_id]))
-        except ValueError:
-            log.error(f"Collection {self.metadata_collection_name} not found")
-        except Exception as e:
-            log.error(f"Error deleting openai vector store {store_id}: {e}")
-            raise
-
-    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-        """Delete vector store file metadata from persistent storage."""
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
-        """Load vector store file metadata from persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
-        """Load vector store file contents from persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _save_openai_vector_store_file(
-        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
-    ) -> None:
-        """Save vector store file metadata to persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
-
-    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
-        """Update vector store file metadata in persistent storage."""
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
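With these overrides removed, persistence of OpenAI vector store metadata presumably falls back to the inherited OpenAIVectorStoreMixin implementations. For reference, the pattern the deleted code implemented can be reproduced standalone: one dedicated Chroma collection keyed by store id, with each store's full state serialized as a JSON string in record metadata. A hedged sketch; the collection name, path, dummy embedding, and store contents here are illustrative, not from the repo:

import json

import chromadb

client = chromadb.PersistentClient(path="/tmp/chroma-demo")  # illustrative path
collection = client.get_or_create_collection("openai_vector_stores")  # illustrative name

store_info = {"id": "vs_123", "name": "demo store"}  # illustrative store state
collection.upsert(
    ids=[store_info["id"]],
    embeddings=[[0.0, 0.0, 0.0]],  # dummy vector; only the metadata matters here
    metadatas=[{"store_id": store_info["id"], "metadata": json.dumps(store_info)}],
)

# Load everything back in batches, as the deleted _load_openai_vector_stores did.
stores: dict[str, dict] = {}
offset, batch_size = 0, 100
while True:
    result = collection.get(offset=offset, limit=batch_size, include=["metadatas"])
    if not result["ids"]:
        break
    for meta in result["metadatas"]:
        if meta and "store_id" in meta:
            stores[meta["store_id"]] = json.loads(meta["metadata"])
    offset += batch_size

print(stores)  # {'vs_123': {'id': 'vs_123', 'name': 'demo store'}}

Chroma metadata values must be scalars (str, int, float, bool), which is why the nested store dict is round-tripped through json.dumps rather than stored directly.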
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
             return

     pytest.skip("OpenAI vector stores are not supported by any provider")

@@ -31,7 +31,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
 def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector", "inline::chromadb"]:
+        if p.provider_type in [
+            "inline::faiss",
+            "inline::sqlite-vec",
+            "inline::milvus",
+            "remote::pgvector",
+            "inline::chromadb",
+        ]:
             return

     pytest.skip("OpenAI vector stores are not supported by any provider")
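These helpers gate integration tests on provider support: each test calls one of them first and is skipped unless a provider from the listed types is configured. A hypothetical test showing the intended call pattern; the test name and the vector_stores client surface are illustrative assumptions:

def test_openai_vector_store_create(client_with_models):
    # Skip unless a provider from the supported list is configured.
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    vector_store = client_with_models.vector_stores.create(name="demo")
    assert vector_store.id is not None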
@@ -12,11 +12,13 @@ from pymilvus import MilvusClient, connections

 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
 from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter

 EMBEDDING_DIMENSION = 384
@@ -236,15 +238,54 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
     await adapter.shutdown()


+@pytest.fixture
+def chroma_vec_db_path(tmp_path_factory):
+    persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
+    return str(persist_dir)
+
+
+@pytest.fixture
+async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
+    index = ChromaIndex(
+        embedding_dimension=embedding_dimension,
+        persist_directory=chroma_vec_db_path,
+    )
+    await index.initialize()
+    yield index
+    await index.delete()
+
+
+@pytest.fixture
+async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+    config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path)
+    adapter = ChromaVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    yield adapter
+    await adapter.shutdown()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     """Returns the appropriate vector IO adapter based on the provider parameter."""
-    if vector_provider == "milvus":
-        return request.getfixturevalue("milvus_vec_adapter")
-    elif vector_provider == "faiss":
-        return request.getfixturevalue("faiss_vec_adapter")
-    else:
-        return request.getfixturevalue("sqlite_vec_adapter")
+    vector_provider_dict = {
+        "milvus": "milvus_vec_adapter",
+        "faiss": "faiss_vec_adapter",
+        "sqlite_vec": "sqlite_vec_adapter",
+        "chroma": "chroma_vec_adapter",
+    }
+    return request.getfixturevalue(vector_provider_dict[vector_provider])


 @pytest.fixture
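The dict-based dispatch in vector_io_adapter only pays off if vector_provider is parametrized over all four keys, so every test that uses the adapter runs once per backend, chroma included. A sketch of what such a fixture presumably looks like; the actual parametrization lives elsewhere in the test setup:

import pytest


@pytest.fixture(params=["milvus", "faiss", "sqlite_vec", "chroma"])
def vector_provider(request):
    # Each test that depends on vector_io_adapter runs once per backend.
    return request.param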