From f9794f847550d7d32298c50b2fc59b127e69a92d Mon Sep 17 00:00:00 2001 From: ChristianZaccaria Date: Fri, 29 Aug 2025 17:29:50 +0100 Subject: [PATCH] fix: update Weaviate fixtures in conftest.py and improve vector DB handling --- .../remote/vector_io/weaviate/weaviate.py | 43 +++++++------ tests/unit/providers/vector_io/conftest.py | 60 +++++++++---------- 2 files changed, 49 insertions(+), 54 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 06bfdf397..02ab9d7c9 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -48,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class WeaviateIndex(EmbeddingIndex): def __init__( self, - client: weaviate.Client, + client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None, ): @@ -65,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex): ) data_objects = [] - for i, chunk in enumerate(chunks): + for chunk, embedding in zip(chunks, embeddings, strict=False): data_objects.append( wvc.data.DataObject( properties={ "chunk_id": chunk.chunk_id, "chunk_content": chunk.model_dump_json(), }, - vector=embeddings[i].tolist(), + vector=embedding.tolist(), ) ) @@ -346,7 +346,7 @@ class WeaviateVectorIOAdapter( ], ) - self.cache[sanitized_collection_name] = VectorDBWithIndex( + self.cache[vector_db.identifier] = VectorDBWithIndex( vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api, @@ -355,32 +355,34 @@ class WeaviateVectorIOAdapter( async def unregister_vector_db(self, vector_db_id: str) -> None: client = self._get_client() sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: - log.warning(f"Vector DB {sanitized_collection_name} not found") + if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False: return client.collections.delete(sanitized_collection_name) - await self.cache[sanitized_collection_name].index.delete() - del self.cache[sanitized_collection_name] + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - if sanitized_collection_name in self.cache: - return self.cache[sanitized_collection_name] + if vector_db_id in self.cache: + return self.cache[vector_db_id] - vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name) + if self.vector_db_store is None: + raise VectorStoreNotFoundError(vector_db_id) + + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: raise VectorStoreNotFoundError(vector_db_id) client = self._get_client() - if not client.collections.exists(vector_db.identifier): + sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True) + if not client.collections.exists(sanitized_collection_name): raise ValueError(f"Collection with name `{sanitized_collection_name}` not found") index = VectorDBWithIndex( vector_db=vector_db, - index=WeaviateIndex(client=client, collection_name=sanitized_collection_name), + index=WeaviateIndex(client=client, collection_name=vector_db.identifier), inference_api=self.inference_api, ) - self.cache[sanitized_collection_name] = index + self.cache[vector_db_id] = index return index async def insert_chunks( @@ -389,8 +391,7 @@ class WeaviateVectorIOAdapter( chunks: list[Chunk], ttl_seconds: int | None = None, ) -> None: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise VectorStoreNotFoundError(vector_db_id) @@ -402,17 +403,15 @@ class WeaviateVectorIOAdapter( query: InterleavedContent, params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: - sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(store_id) if not index: - raise ValueError(f"Vector DB {sanitized_collection_name} not found") + raise ValueError(f"Vector DB {store_id} not found") await index.index.delete_chunks(chunks_for_deletion) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index e97cc0822..876e0401d 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -448,45 +448,28 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): yield adapter await adapter.shutdown() -def weaviate_vec_db_path(): - return "localhost:8080" + + +@pytest.fixture(scope="session") +def weaviate_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db") + return db_path @pytest.fixture async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension): - import uuid - + import pytest_socket import weaviate - # Connect to local Weaviate instance - client = weaviate.connect_to_local( - host="localhost", + pytest_socket.enable_socket() + client = weaviate.connect_to_embedded( + hostname="localhost", port=8080, + grpc_port=50051, + persistence_data_path=weaviate_vec_db_path, ) - - collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}" - index = WeaviateIndex(client=client, collection_name=collection_name) - - # Create the collection for this test - import weaviate.classes as wvc - from weaviate.collections.classes.config import _CollectionConfig - - from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name - - sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True) - collection_config = _CollectionConfig( - name=sanitized_name, - vectorizer_config=wvc.config.Configure.Vectorizer.none(), - properties=[ - wvc.config.Property( - name="chunk_content", - data_type=wvc.config.DataType.TEXT, - ), - ], - ) - if not client.collections.exists(sanitized_name): - client.collections.create_from_config(collection_config) - + index = WeaviateIndex(client=client, collection_name="Testcollection") + await index.initialize() yield index await index.delete() client.close() @@ -494,8 +477,20 @@ async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension): @pytest.fixture async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): + import pytest_socket + import weaviate + + pytest_socket.enable_socket() + + client = weaviate.connect_to_embedded( + hostname="localhost", + port=8080, + grpc_port=50051, + persistence_data_path=weaviate_vec_db_path, + ) + config = WeaviateVectorIOConfig( - weaviate_cluster_url=weaviate_vec_db_path, + weaviate_cluster_url="localhost:8080", weaviate_api_key=None, kvstore=SqliteKVStoreConfig(), ) @@ -517,6 +512,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi adapter.test_collection_id = collection_id yield adapter await adapter.shutdown() + client.close() @pytest.fixture