Refactor WeaviateMemoryAdapter initialization and client handling

This commit is contained in:
Zain Hasan 2024-09-24 19:11:32 -04:00
parent af1710af75
commit aca4a4d7fc

View file

@ -75,9 +75,9 @@ class WeaviateMemoryAdapter(Memory):
self.client = None self.client = None
self.cache = {} self.cache = {}
async def initialize_client(self) -> weaviate.Client: def _get_client(self) -> weaviate.Client:
try:
request_provider_data = get_request_provider_data() request_provider_data = get_request_provider_data()
if request_provider_data is not None: if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData) assert isinstance(request_provider_data, WeaviateRequestProviderData)
@ -85,14 +85,18 @@ class WeaviateMemoryAdapter(Memory):
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}") print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud # Connect to Weaviate Cloud
client = weaviate.connect_to_weaviate_cloud( return weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url, cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
) )
async def initialize(self) -> None:
try:
self.client = self._get_client()
# Create collection if it doesn't exist # Create collection if it doesn't exist
if not client.collections.exists(self.config.collection): if not self.client.collections.exists(self.config.collection):
client.collections.create( self.client.collections.create(
name = self.config.collection, name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(), vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[ properties=[
@ -102,8 +106,6 @@ class WeaviateMemoryAdapter(Memory):
), ),
] ]
) )
return client
except Exception as e: except Exception as e:
import traceback import traceback
@ -111,7 +113,7 @@ class WeaviateMemoryAdapter(Memory):
raise RuntimeError("Could not connect to Weaviate server") from e raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client = self.initialize_client() self.client = self._get_client()
if self.client: if self.client:
self.client.close() self.client.close()
@ -129,7 +131,8 @@ class WeaviateMemoryAdapter(Memory):
config=config, config=config,
url=url, url=url,
) )
self.client = self.initialize_client() self.client = self._get_client()
# Store the bank as a new collection in Weaviate # Store the bank as a new collection in Weaviate
self.client.collections.create( self.client.collections.create(
name=bank_id name=bank_id
@ -150,7 +153,7 @@ class WeaviateMemoryAdapter(Memory):
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self.initialize_client() self.client = self._get_client()
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
@ -173,7 +176,6 @@ class WeaviateMemoryAdapter(Memory):
self, self,
bank_id: str, bank_id: str,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
index = await self._get_and_cache_bank_index(bank_id) index = await self._get_and_cache_bank_index(bank_id)
if not index: if not index: