Refactor WeaviateMemoryAdapter initialization and client handling

This commit is contained in:
Zain Hasan 2024-09-24 19:11:32 -04:00
parent af1710af75
commit aca4a4d7fc

View file

@ -75,9 +75,9 @@ class WeaviateMemoryAdapter(Memory):
self.client = None
self.cache = {}
async def initialize_client(self) -> weaviate.Client:
try:
def _get_client(self) -> weaviate.Client:
request_provider_data = get_request_provider_data()
if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData)
@ -85,14 +85,18 @@ class WeaviateMemoryAdapter(Memory):
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud
client = weaviate.connect_to_weaviate_cloud(
return weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)
async def initialize(self) -> None:
try:
self.client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(self.config.collection):
client.collections.create(
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
@ -102,8 +106,6 @@ class WeaviateMemoryAdapter(Memory):
),
]
)
return client
except Exception as e:
import traceback
@ -111,7 +113,7 @@ class WeaviateMemoryAdapter(Memory):
raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None:
self.client = self.initialize_client()
self.client = self._get_client()
if self.client:
self.client.close()
@ -129,7 +131,8 @@ class WeaviateMemoryAdapter(Memory):
config=config,
url=url,
)
self.client = self.initialize_client()
self.client = self._get_client()
# Store the bank as a new collection in Weaviate
self.client.collections.create(
name=bank_id
@ -150,7 +153,7 @@ class WeaviateMemoryAdapter(Memory):
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self.initialize_client()
self.client = self._get_client()
if bank_id in self.cache:
return self.cache[bank_id]
@ -173,7 +176,6 @@ class WeaviateMemoryAdapter(Memory):
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index: