Weaviate "should" work (i.e., it is code-complete) but has not been tested

This commit is contained in:
Ashwin Bharambe 2024-10-06 22:50:34 -07:00 committed by Ashwin Bharambe
parent 118c0ef105
commit a05599c67a

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import weaviate
@ -23,9 +22,9 @@ from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection: str):
def __init__(self, client: weaviate.Client, collection_name: str):
self.client = client
self.collection = collection
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
@ -44,17 +43,13 @@ class WeaviateIndex(EmbeddingIndex):
)
# Inserting chunks into a prespecified Weaviate collection
assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection)
await my_collection.data.insert_many(data_objects)
collection = self.client.collections.get(self.collection_name)
await collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
assert self.collection is not None, "Collection name must be specified"
collection = self.client.collections.get(self.collection_name)
my_collection = self.client.collections.get(self.collection)
results = my_collection.query.near_vector(
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_meta_data=wvc.query.MetadataQuery(distance=True),
@ -63,17 +58,10 @@ class WeaviateIndex(EmbeddingIndex):
chunks = []
scores = []
for doc in results.objects:
try:
chunk = doc.properties["chunk_content"]
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
except Exception as e:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {e}")
return QueryDocumentsResponse(chunks=chunks, scores=scores)
@ -131,7 +119,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
index = BankWithIndex(
bank=memory_bank,
index=WeaviateIndex(client=client, collection=memory_bank.identifier),
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
)
self.cache[bank_id] = index
@ -144,20 +132,16 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
raise ValueError(f"Bank {bank_id} not found")
client = await self._get_client()
collections = await client.collections.list_all().keys()
if not client.collections.exists(bank_id):
raise ValueError(f"Collection with name `{bank_id}` not found")
for collection in collections:
if collection == bank_id:
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(self.client, collection),
index=WeaviateIndex(client=client, collection_name=bank_id),
)
self.cache[bank_id] = index
return index
return None
async def insert_documents(
self,
bank_id: str,