From b509d59dcd815e0e153d64f81b6c4a11d1a1af61 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 10 Dec 2024 16:45:33 -0800 Subject: [PATCH] weaviate fixes --- .../providers/remote/memory/weaviate/weaviate.py | 11 +++++++++-- llama_stack/providers/tests/memory/test_memory.py | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 3fa9ace51..b409b697b 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -12,10 +12,11 @@ import weaviate import weaviate.classes as wvc from numpy.typing import NDArray from weaviate.classes.init import Auth +from weaviate.classes.query import Filter from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -80,6 +81,12 @@ class WeaviateIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self, chunk_ids: List[str]) -> None: + collection = self.client.collections.get(self.collection_name) + collection.data.delete_many( + where=Filter.by_property("id").contains_any(chunk_ids) + ) + class WeaviateMemoryAdapter( Memory, @@ -120,7 +127,7 @@ class WeaviateMemoryAdapter( memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.memory_bank_type == MemoryBankType.vector + memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" client = self._get_client() diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 85f4351f8..526aa646c 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -169,13 +169,13 @@ class TestMemory: # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.2} + params5 = {"score_threshold": 0.01} response5 = await memory_impl.query_documents( registered_bank.memory_bank_id, query5, params5 ) assert_valid_response(response5) print("The scores are:", response5.scores) - assert all(score >= 0.2 for score in response5.scores) + assert all(score >= 0.01 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse):