Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-25 21:57:45 +00:00
chore: Enabling tests for Weaviate
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
chore: Actually enabling Chroma unit tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
fixed tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
fix integration test
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
remove changes from weaviate
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Parent: cd8715d327
Commit: 2defebc835
3 changed files with 56 additions and 28 deletions
@@ -57,12 +57,16 @@ class ChromaIndex(EmbeddingIndex):
         self.collection = collection
         self.kvstore = kvstore
 
+    async def initialize(self):
+        # Chroma does not require explicit initialization, this is just a helper for unit tests
+        pass
+
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
 
-        ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
+        ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
         await maybe_await(
             self.collection.add(
                 documents=[chunk.model_dump_json() for chunk in chunks],
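For readability, a small self-contained illustration (with invented document and chunk IDs, not taken from this diff) of what the new Chroma document ID scheme produces compared to the old positional one:

```python
# Invented sample data, for illustration only.
chunks = [
    {"document_id": "doc-1", "chunk_id": "a1b2c3"},
    {"document_id": "doc-1", "chunk_id": "d4e5f6"},
]

# Old scheme: position-based suffix, so an ID changes whenever chunk order changes.
old_ids = [f"{c['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
print(old_ids)  # ['doc-1:chunk-0', 'doc-1:chunk-1']

# New scheme: a stable identifier taken from the chunk itself.
new_ids = [f"{c['document_id']}:{c['chunk_id']}" for c in chunks]
print(new_ids)  # ['doc-1:a1b2c3', 'doc-1:d4e5f6']
```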
@@ -137,9 +141,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self.client = None
         self.cache = {}
+        self.kvstore: KVStore | None = None
         self.vector_db_store = None
 
     async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        self.vector_db_store = self.kvstore
 
         if isinstance(self.config, RemoteChromaVectorIOConfig):
             log.info(f"Connecting to Chroma server at: {self.config.url}")
             url = self.config.url.rstrip("/")
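The new initialize() wires the key-value store in before any Chroma connection is made and reuses it as the adapter's vector DB registry. A minimal illustrative sketch of that pattern with an invented in-memory store (FakeKVStore and SketchAdapter are not llama-stack classes):

```python
import asyncio


class FakeKVStore:
    """Invented in-memory stand-in for the configured kvstore."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def get(self, key: str) -> str | None:
        return self._data.get(key)


class SketchAdapter:
    def __init__(self) -> None:
        self.kvstore: FakeKVStore | None = None
        self.vector_db_store = None

    async def initialize(self) -> None:
        # Same shape as the hunk above: build the store from config,
        # then point the registry at it before connecting to the backend.
        self.kvstore = FakeKVStore()
        self.vector_db_store = self.kvstore


asyncio.run(SketchAdapter().initialize())
```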
@@ -172,6 +179,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         )
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
+        if vector_db_id not in self.cache:
+            log.warning(f"Vector DB {vector_db_id} not found")
+            return
+
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]
 
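The guard added to unregister_vector_db makes deregistration tolerant of unknown IDs: it logs a warning instead of raising a KeyError. A small standalone sketch of the same pattern (the names and cache contents are illustrative):

```python
import logging

log = logging.getLogger(__name__)
cache: dict[str, object] = {"db-1": object()}


def unregister(vector_db_id: str) -> None:
    # Warn and return instead of raising KeyError on an unknown ID.
    if vector_db_id not in cache:
        log.warning(f"Vector DB {vector_db_id} not found")
        return
    del cache[vector_db_id]


unregister("db-2")  # logs a warning, no exception
unregister("db-1")  # removes the entry
```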
@@ -182,6 +193,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         ttl_seconds: int | None = None,
     ) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if index is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
 
         await index.insert_chunks(chunks)
 
@@ -193,18 +206,27 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
 
+        if index is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+
         return await index.query_chunks(query, params)
 
     async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]
 
-        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        if not vector_db:
-            raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
-        collection = await maybe_await(self.client.get_collection(vector_db_id))
-        if not collection:
-            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
-        index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
-        self.cache[vector_db_id] = index
-        return index
+        try:
+            collection = await maybe_await(self.client.get_collection(vector_db_id))
+            if not collection:
+                raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+
+            vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+            if not vector_db:
+                raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
+
+            index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
+            self.cache[vector_db_id] = index
+            return index
+
+        except Exception as exc:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") from exc
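The rewritten _get_and_cache_vector_db_index checks the Chroma collection before the Llama Stack registry and wraps the whole lookup in try/except so any backend failure surfaces as a uniform "not found" error. A rough sketch of that control flow with invented stand-in callables (get_index, get_collection, and get_vector_db here are illustrative, not the adapter's real methods):

```python
import asyncio


async def get_index(vector_db_id: str, cache: dict, get_collection, get_vector_db):
    """Sketch of the lookup order: cache -> Chroma collection -> registry."""
    if vector_db_id in cache:
        return cache[vector_db_id]

    try:
        collection = await get_collection(vector_db_id)
        if not collection:
            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")

        vector_db = await get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")

        index = (vector_db, collection)  # stand-in for VectorDBWithIndex
        cache[vector_db_id] = index
        return index
    except Exception as exc:
        # Any backend failure is normalized to the same "not found" error.
        raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") from exc


async def main():
    cache = {}

    async def get_collection(_id):  # pretend the collection exists
        return {"name": _id}

    async def get_vector_db(_id):  # pretend the registry entry exists
        return {"identifier": _id}

    print(await get_index("db-1", cache, get_collection, get_vector_db))


asyncio.run(main())
```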
@@ -22,7 +22,14 @@ logger = logging.getLogger(__name__)
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
+        if p.provider_type in [
+            "inline::faiss",
+            "inline::sqlite-vec",
+            "inline::milvus",
+            "inline::chromadb",
+            "remote::pgvector",
+            "remote::chromadb",
+        ]:
             return
 
     pytest.skip("OpenAI vector stores are not supported by any provider")
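The helper is meant to be called at the top of an integration test so the test is skipped on any provider outside the supported list. A hedged sketch of that usage; the test name and body are invented:

```python
def test_openai_vector_store_roundtrip(client_with_models):
    # Skip unless the active vector_io provider supports OpenAI vector stores.
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    # ... exercise the OpenAI-compatible vector store endpoints here ...
```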
@@ -31,12 +38,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
 def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in [
-            "inline::faiss",
-            "inline::sqlite-vec",
-            "inline::milvus",
-            "remote::pgvector",
-        ]:
+        if p.provider_type in []:
             return
 
     pytest.skip("OpenAI vector stores are not supported by any provider")
@@ -8,6 +8,7 @@ import random
 
 import numpy as np
 import pytest
+from chromadb import PersistentClient
 from pymilvus import MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
@@ -18,7 +19,7 @@ from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, Faiss
 from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
-from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
+from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
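maybe_await (now imported alongside ChromaIndex) exists because the Chroma client can be either synchronous (PersistentClient returns plain values) or asynchronous (an HTTP client whose methods return coroutines). A hedged sketch of the usual shape of such a helper, assuming it simply awaits coroutine results and passes everything else through:

```python
import asyncio


async def maybe_await_sketch(result):
    # Await only if the underlying client returned a coroutine;
    # synchronous clients return plain values that pass straight through.
    if asyncio.iscoroutine(result):
        return await result
    return result
```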
@@ -26,6 +27,11 @@ COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"
 
 
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+def vector_provider(request):
+    return request.param
+
+
 @pytest.fixture
 def vector_db_id() -> str:
     return f"test-vector-db-{random.randint(1, 100)}"
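With "chroma" added to the params list, every test that depends on vector_provider, directly or through a downstream fixture, now also runs against the Chroma backend. A hedged sketch of how such a parametrized fixture is commonly consumed; the vector_io_adapter indirection and test below are invented for illustration:

```python
import pytest


@pytest.fixture
def vector_io_adapter(vector_provider, request):
    # Resolve the backend-specific adapter fixture by name, e.g.
    # "chroma_vec_adapter" when vector_provider == "chroma".
    return request.getfixturevalue(f"{vector_provider}_vec_adapter")


async def test_roundtrip(vector_io_adapter):
    # Runs once per backend in the params list: milvus, sqlite_vec, faiss, chroma.
    ...
```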
@@ -94,11 +100,6 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
     return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
 
 
-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
-def vector_provider(request):
-    return request.param
-
-
 @pytest.fixture(scope="session")
 def mock_inference_api(embedding_dimension):
     class MockInferenceAPI:
@@ -246,10 +247,10 @@ def chroma_vec_db_path(tmp_path_factory):
 
 @pytest.fixture
 async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
-    index = ChromaIndex(
-        embedding_dimension=embedding_dimension,
-        persist_directory=chroma_vec_db_path,
-    )
+    client = PersistentClient(path=chroma_vec_db_path)
+    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
+    collection = await maybe_await(client.get_or_create_collection(name))
+    index = ChromaIndex(client=client, collection=collection)
     await index.initialize()
     yield index
     await index.delete()
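The rewritten fixture builds a real chromadb PersistentClient and collection instead of passing unsupported constructor arguments to ChromaIndex. A minimal standalone sketch of the chromadb calls it relies on; the path, collection name, and document are placeholders, not values from the test suite:

```python
import numpy as np
from chromadb import PersistentClient

# Placeholder path and collection name, mirroring what the fixture sets up.
client = PersistentClient(path="/tmp/chroma-unit-test")
collection = client.get_or_create_collection(f"test_collection_{np.random.randint(1e6)}")

# Add one chunk with an explicit embedding so no default embedding model is needed.
collection.add(
    ids=["doc-1:abc"],
    documents=['{"content": "hello"}'],
    embeddings=[np.random.rand(384).astype(np.float32).tolist()],
)
print(collection.get(ids=["doc-1:abc"])["documents"])
```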
@@ -257,7 +258,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
 
 @pytest.fixture
 async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
-    config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path)
+    config = ChromaVectorIOConfig(
+        db_path=chroma_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
     adapter = ChromaVectorIOAdapter(
         config=config,
         inference_api=mock_inference_api,