weaviate fixes

This commit is contained in:
Dinesh Yeduguru 2024-12-10 16:45:33 -08:00
parent 0e451525e5
commit b509d59dcd
2 changed files with 11 additions and 4 deletions

View file

@ -12,10 +12,11 @@ import weaviate
import weaviate.classes as wvc import weaviate.classes as wvc
from numpy.typing import NDArray from numpy.typing import NDArray
from weaviate.classes.init import Auth from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
@ -80,6 +81,12 @@ class WeaviateIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(
where=Filter.by_property("id").contains_any(chunk_ids)
)
class WeaviateMemoryAdapter( class WeaviateMemoryAdapter(
Memory, Memory,
@ -120,7 +127,7 @@ class WeaviateMemoryAdapter(
memory_bank: MemoryBank, memory_bank: MemoryBank,
) -> None: ) -> None:
assert ( assert (
memory_bank.memory_bank_type == MemoryBankType.vector memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client() client = self._get_client()

View file

@ -169,13 +169,13 @@ class TestMemory:
# Test case 5: Query with threshold on similarity score # Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2} params5 = {"score_threshold": 0.01}
response5 = await memory_impl.query_documents( response5 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query5, params5 registered_bank.memory_bank_id, query5, params5
) )
assert_valid_response(response5) assert_valid_response(response5)
print("The scores are:", response5.scores) print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores) assert all(score >= 0.01 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse): def assert_valid_response(response: QueryDocumentsResponse):