fix: update Weaviate fixtures in conftest.py and improve vector DB handling

This commit is contained in:
ChristianZaccaria 2025-08-29 17:29:50 +01:00
parent 4541b517c8
commit f9794f8475
2 changed files with 49 additions and 54 deletions

View file

@ -48,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class WeaviateIndex(EmbeddingIndex): class WeaviateIndex(EmbeddingIndex):
def __init__( def __init__(
self, self,
client: weaviate.Client, client: weaviate.WeaviateClient,
collection_name: str, collection_name: str,
kvstore: KVStore | None = None, kvstore: KVStore | None = None,
): ):
@ -65,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex):
) )
data_objects = [] data_objects = []
for i, chunk in enumerate(chunks): for chunk, embedding in zip(chunks, embeddings, strict=False):
data_objects.append( data_objects.append(
wvc.data.DataObject( wvc.data.DataObject(
properties={ properties={
"chunk_id": chunk.chunk_id, "chunk_id": chunk.chunk_id,
"chunk_content": chunk.model_dump_json(), "chunk_content": chunk.model_dump_json(),
}, },
vector=embeddings[i].tolist(), vector=embedding.tolist(),
) )
) )
@ -346,7 +346,7 @@ class WeaviateVectorIOAdapter(
], ],
) )
self.cache[sanitized_collection_name] = VectorDBWithIndex( self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, vector_db,
WeaviateIndex(client=client, collection_name=sanitized_collection_name), WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api, self.inference_api,
@ -355,32 +355,34 @@ class WeaviateVectorIOAdapter(
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_db(self, vector_db_id: str) -> None:
client = self._get_client() client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
log.warning(f"Vector DB {sanitized_collection_name} not found")
return return
client.collections.delete(sanitized_collection_name) client.collections.delete(sanitized_collection_name)
await self.cache[sanitized_collection_name].index.delete() await self.cache[vector_db_id].index.delete()
del self.cache[sanitized_collection_name] del self.cache[vector_db_id]
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) if vector_db_id in self.cache:
if sanitized_collection_name in self.cache: return self.cache[vector_db_id]
return self.cache[sanitized_collection_name]
vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name) if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db: if not vector_db:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
client = self._get_client() client = self._get_client()
if not client.collections.exists(vector_db.identifier): sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
if not client.collections.exists(sanitized_collection_name):
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found") raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=WeaviateIndex(client=client, collection_name=sanitized_collection_name), index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[sanitized_collection_name] = index self.cache[vector_db_id] = index
return index return index
async def insert_chunks( async def insert_chunks(
@ -389,8 +391,7 @@ class WeaviateVectorIOAdapter(
chunks: list[Chunk], chunks: list[Chunk],
ttl_seconds: int | None = None, ttl_seconds: int | None = None,
) -> None: ) -> None:
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
@ -402,17 +403,15 @@ class WeaviateVectorIOAdapter(
query: InterleavedContent, query: InterleavedContent,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index: if not index:
raise ValueError(f"Vector DB {sanitized_collection_name} not found") raise ValueError(f"Vector DB {store_id} not found")
await index.index.delete_chunks(chunks_for_deletion) await index.index.delete_chunks(chunks_for_deletion)

View file

@ -448,45 +448,28 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
yield adapter yield adapter
await adapter.shutdown() await adapter.shutdown()
def weaviate_vec_db_path():
return "localhost:8080"
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture @pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension): async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
import uuid import pytest_socket
import weaviate import weaviate
# Connect to local Weaviate instance pytest_socket.enable_socket()
client = weaviate.connect_to_local( client = weaviate.connect_to_embedded(
host="localhost", hostname="localhost",
port=8080, port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
) )
index = WeaviateIndex(client=client, collection_name="Testcollection")
collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}" await index.initialize()
index = WeaviateIndex(client=client, collection_name=collection_name)
# Create the collection for this test
import weaviate.classes as wvc
from weaviate.collections.classes.config import _CollectionConfig
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
collection_config = _CollectionConfig(
name=sanitized_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
if not client.collections.exists(sanitized_name):
client.collections.create_from_config(collection_config)
yield index yield index
await index.delete() await index.delete()
client.close() client.close()
@ -494,8 +477,20 @@ async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
@pytest.fixture @pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig( config = WeaviateVectorIOConfig(
weaviate_cluster_url=weaviate_vec_db_path, weaviate_cluster_url="localhost:8080",
weaviate_api_key=None, weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(), kvstore=SqliteKVStoreConfig(),
) )
@ -517,6 +512,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi
adapter.test_collection_id = collection_id adapter.test_collection_id = collection_id
yield adapter yield adapter
await adapter.shutdown() await adapter.shutdown()
client.close()
@pytest.fixture @pytest.fixture