diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index cd03a8618..ecd844551 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -75,9 +75,9 @@ class WeaviateMemoryAdapter(Memory): self.client = None self.cache = {} - async def initialize_client(self) -> weaviate.Client: - try: + def _get_client(self) -> weaviate.Client: request_provider_data = get_request_provider_data() + if request_provider_data is not None: assert isinstance(request_provider_data, WeaviateRequestProviderData) @@ -85,14 +85,18 @@ class WeaviateMemoryAdapter(Memory): print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}") # Connect to Weaviate Cloud - client = weaviate.connect_to_weaviate_cloud( + return weaviate.connect_to_weaviate_cloud( cluster_url = request_provider_data.weaviate_cluster_url, auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), ) + async def initialize(self) -> None: + try: + self.client = self._get_client() + # Create collection if it doesn't exist - if not client.collections.exists(self.config.collection): - client.collections.create( + if not self.client.collections.exists(self.config.collection): + self.client.collections.create( name = self.config.collection, vectorizer_config = wvc.config.Configure.Vectorizer.none(), properties=[ @@ -102,8 +106,6 @@ class WeaviateMemoryAdapter(Memory): ), ] ) - - return client except Exception as e: import traceback @@ -111,7 +113,7 @@ class WeaviateMemoryAdapter(Memory): raise RuntimeError("Could not connect to Weaviate server") from e async def shutdown(self) -> None: - self.client = self.initialize_client() + self.client = self._get_client() if self.client: self.client.close() @@ -129,7 +131,8 @@ class WeaviateMemoryAdapter(Memory): config=config, url=url, ) - self.client = self.initialize_client() + self.client = self._get_client() + # Store the bank as a new collection in Weaviate self.client.collections.create( name=bank_id @@ -150,7 +153,7 @@ class WeaviateMemoryAdapter(Memory): async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: - self.client = self.initialize_client() + self.client = self._get_client() if bank_id in self.cache: return self.cache[bank_id] @@ -173,7 +176,6 @@ class WeaviateMemoryAdapter(Memory): self, bank_id: str, documents: List[MemoryBankDocument], - ttl_seconds: Optional[int] = None, ) -> None: index = await self._get_and_cache_bank_index(bank_id) if not index: