mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
weaviate fixes
This commit is contained in:
parent
0e451525e5
commit
b509d59dcd
2 changed files with 11 additions and 4 deletions
|
@ -12,10 +12,11 @@ import weaviate
|
||||||
import weaviate.classes as wvc
|
import weaviate.classes as wvc
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from weaviate.classes.init import Auth
|
from weaviate.classes.init import Auth
|
||||||
|
from weaviate.classes.query import Filter
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
|
@ -80,6 +81,12 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def delete(self, chunk_ids: List[str]) -> None:
|
||||||
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
collection.data.delete_many(
|
||||||
|
where=Filter.by_property("id").contains_any(chunk_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateMemoryAdapter(
|
class WeaviateMemoryAdapter(
|
||||||
Memory,
|
Memory,
|
||||||
|
@ -120,7 +127,7 @@ class WeaviateMemoryAdapter(
|
||||||
memory_bank: MemoryBank,
|
memory_bank: MemoryBank,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert (
|
assert (
|
||||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||||
|
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
|
@ -169,13 +169,13 @@ class TestMemory:
|
||||||
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
# Test case 5: Query with threshold on similarity score
|
||||||
query5 = "quantum computing" # Not directly related to any document
|
query5 = "quantum computing" # Not directly related to any document
|
||||||
params5 = {"score_threshold": 0.2}
|
params5 = {"score_threshold": 0.01}
|
||||||
response5 = await memory_impl.query_documents(
|
response5 = await memory_impl.query_documents(
|
||||||
registered_bank.memory_bank_id, query5, params5
|
registered_bank.memory_bank_id, query5, params5
|
||||||
)
|
)
|
||||||
assert_valid_response(response5)
|
assert_valid_response(response5)
|
||||||
print("The scores are:", response5.scores)
|
print("The scores are:", response5.scores)
|
||||||
assert all(score >= 0.2 for score in response5.scores)
|
assert all(score >= 0.01 for score in response5.scores)
|
||||||
|
|
||||||
|
|
||||||
def assert_valid_response(response: QueryDocumentsResponse):
|
def assert_valid_response(response: QueryDocumentsResponse):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue