Weaviate "should" work (i.e., it is code-complete) but has not been tested

This commit is contained in:
Ashwin Bharambe 2024-10-06 22:50:34 -07:00 committed by Ashwin Bharambe
parent 118c0ef105
commit a05599c67a

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import weaviate import weaviate
@ -23,9 +22,9 @@ from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex): class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection: str): def __init__(self, client: weaviate.Client, collection_name: str):
self.client = client self.client = client
self.collection = collection self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len( assert len(chunks) == len(
@ -44,17 +43,13 @@ class WeaviateIndex(EmbeddingIndex):
) )
# Inserting chunks into a prespecified Weaviate collection # Inserting chunks into a prespecified Weaviate collection
assert self.collection is not None, "Collection name must be specified" collection = self.client.collections.get(self.collection_name)
my_collection = self.client.collections.get(self.collection) await collection.data.insert_many(data_objects)
await my_collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
assert self.collection is not None, "Collection name must be specified" collection = self.client.collections.get(self.collection_name)
my_collection = self.client.collections.get(self.collection) results = collection.query.near_vector(
results = my_collection.query.near_vector(
near_vector=embedding.tolist(), near_vector=embedding.tolist(),
limit=k, limit=k,
return_meta_data=wvc.query.MetadataQuery(distance=True), return_meta_data=wvc.query.MetadataQuery(distance=True),
@ -63,16 +58,9 @@ class WeaviateIndex(EmbeddingIndex):
chunks = [] chunks = []
scores = [] scores = []
for doc in results.objects: for doc in results.objects:
try: chunk = doc.properties["chunk_content"]
chunk = doc.properties["chunk_content"] chunks.append(chunk)
chunks.append(chunk) scores.append(1.0 / doc.metadata.distance)
scores.append(1.0 / doc.metadata.distance)
except Exception as e:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {e}")
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
@ -131,7 +119,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
index = BankWithIndex( index = BankWithIndex(
bank=memory_bank, bank=memory_bank,
index=WeaviateIndex(client=client, collection=memory_bank.identifier), index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
) )
self.cache[bank_id] = index self.cache[bank_id] = index
@ -144,19 +132,15 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
raise ValueError(f"Bank {bank_id} not found") raise ValueError(f"Bank {bank_id} not found")
client = await self._get_client() client = await self._get_client()
collections = await client.collections.list_all().keys() if not client.collections.exists(bank_id):
raise ValueError(f"Collection with name `{bank_id}` not found")
for collection in collections: index = BankWithIndex(
if collection == bank_id: bank=bank,
bank = MemoryBank(**json.loads(collection.metadata["bank"])) index=WeaviateIndex(client=client, collection_name=bank_id),
index = BankWithIndex( )
bank=bank, self.cache[bank_id] = index
index=WeaviateIndex(self.client, collection), return index
)
self.cache[bank_id] = index
return index
return None
async def insert_documents( async def insert_documents(
self, self,