From a05599c67aeeb3466dbb18b529256c0468c6fcfe Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 6 Oct 2024 22:50:34 -0700 Subject: [PATCH] Weaviate "should" work (i.e., is code-complete) but not tested --- .../adapters/memory/weaviate/weaviate.py | 52 +++++++------------ 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index 9f8e93434..573802c84 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json from typing import Any, Dict, List, Optional import weaviate @@ -23,9 +22,9 @@ from .config import WeaviateConfig, WeaviateRequestProviderData class WeaviateIndex(EmbeddingIndex): - def __init__(self, client: weaviate.Client, collection: str): + def __init__(self, client: weaviate.Client, collection_name: str): self.client = client - self.collection = collection + self.collection_name = collection_name async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): assert len(chunks) == len( @@ -44,17 +43,13 @@ class WeaviateIndex(EmbeddingIndex): ) # Inserting chunks into a prespecified Weaviate collection - assert self.collection is not None, "Collection name must be specified" - my_collection = self.client.collections.get(self.collection) - - await my_collection.data.insert_many(data_objects) + collection = self.client.collections.get(self.collection_name) + await collection.data.insert_many(data_objects) async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: - assert self.collection is not None, "Collection name must be specified" + collection = self.client.collections.get(self.collection_name) - my_collection = self.client.collections.get(self.collection) - - results = my_collection.query.near_vector( + results = collection.query.near_vector( near_vector=embedding.tolist(), limit=k, return_meta_data=wvc.query.MetadataQuery(distance=True), @@ -63,16 +58,9 @@ class WeaviateIndex(EmbeddingIndex): chunks = [] scores = [] for doc in results.objects: - try: - chunk = doc.properties["chunk_content"] - chunks.append(chunk) - scores.append(1.0 / doc.metadata.distance) - - except Exception as e: - import traceback - - traceback.print_exc() - print(f"Failed to parse document: {e}") + chunk = doc.properties["chunk_content"] + chunks.append(chunk) + scores.append(1.0 / doc.metadata.distance) return QueryDocumentsResponse(chunks=chunks, scores=scores) @@ -131,7 +119,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): index = BankWithIndex( bank=memory_bank, - index=WeaviateIndex(client=client, collection=memory_bank.identifier), + index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), ) self.cache[bank_id] = index @@ -144,19 +132,15 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): raise ValueError(f"Bank {bank_id} not found") client = await self._get_client() - collections = await client.collections.list_all().keys() + if not client.collections.exists(bank_id): + raise ValueError(f"Collection with name `{bank_id}` not found") - for collection in collections: - if collection == bank_id: - bank = MemoryBank(**json.loads(collection.metadata["bank"])) - index = BankWithIndex( - bank=bank, - index=WeaviateIndex(self.client, collection), - ) - self.cache[bank_id] = index - return index - - return None + index = BankWithIndex( + bank=bank, + index=WeaviateIndex(client=client, collection_name=bank_id), + ) + self.cache[bank_id] = index + return index async def insert_documents( self,