diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index cb9d7463d..b4bfcb34d 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -42,11 +42,7 @@ class MemoryClient(Memory): params={ "bank_id": bank_id, }, - headers={ - "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", - "weaviate_cluster_url": "http://localhost:8080"}), - }, + headers={"Content-Type": "application/json"}, timeout=20, ) r.raise_for_status() @@ -69,11 +65,7 @@ class MemoryClient(Memory): "config": config.dict(), "url": url, }, - headers={ - "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", - "weaviate_cluster_url": "http://localhost:8080"}), - }, + headers={"Content-Type": "application/json"}, timeout=20, ) r.raise_for_status() @@ -94,11 +86,7 @@ class MemoryClient(Memory): "bank_id": bank_id, "documents": [d.dict() for d in documents], }, - headers={ - "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", - "weaviate_cluster_url": "http://localhost:8080"}), - }, + headers={"Content-Type": "application/json"}, timeout=20, ) r.raise_for_status() @@ -117,11 +105,7 @@ class MemoryClient(Memory): "query": query, "params": params, }, - headers={ - "Content-Type": "application/json", - "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234", - "weaviate_cluster_url": "http://localhost:8080"}), - }, + headers={"Content-Type": "application/json"}, timeout=20, ) r.raise_for_status() diff --git a/llama_stack/providers/adapters/memory/weaviate/__init__.py b/llama_stack/providers/adapters/memory/weaviate/__init__.py index 6b7855120..b564eabf4 100644 --- a/llama_stack/providers/adapters/memory/weaviate/__init__.py +++ b/llama_stack/providers/adapters/memory/weaviate/__init__.py @@ -1,8 +1,8 @@ -from llama_stack.distribution.datatypes import RemoteProviderConfig +from .config import WeaviateConfig -async def get_adapter_impl(config: RemoteProviderConfig, _deps): +async def get_adapter_impl(config: WeaviateConfig, _deps): from .weaviate import WeaviateMemoryAdapter - impl = WeaviateMemoryAdapter(config.url, config.username, config.password) + impl = WeaviateMemoryAdapter(config) await impl.initialize() return impl \ No newline at end of file diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index 275722be2..cd03a8618 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -75,24 +75,24 @@ class WeaviateMemoryAdapter(Memory): self.client = None self.cache = {} - async def initialize(self) -> None: + async def initialize_client(self) -> weaviate.Client: try: - request_provider_data = get_request_provider_data() if request_provider_data is not None: assert isinstance(request_provider_data, WeaviateRequestProviderData) print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}") print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}") + # Connect to Weaviate Cloud - self.client = weaviate.connect_to_weaviate_cloud( + client = weaviate.connect_to_weaviate_cloud( cluster_url = request_provider_data.weaviate_cluster_url, auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), ) # Create collection if it doesn't exist - if not self.client.collections.exists(self.config.collection): - self.client.collections.create( + if not client.collections.exists(self.config.collection): + client.collections.create( name = self.config.collection, vectorizer_config = wvc.config.Configure.Vectorizer.none(), properties=[ @@ -102,6 +102,8 @@ class WeaviateMemoryAdapter(Memory): ), ] ) + + return client except Exception as e: import traceback @@ -109,6 +111,8 @@ class WeaviateMemoryAdapter(Memory): raise RuntimeError("Could not connect to Weaviate server") from e async def shutdown(self) -> None: + self.client = self.initialize_client() + if self.client: self.client.close() @@ -125,7 +129,7 @@ class WeaviateMemoryAdapter(Memory): config=config, url=url, ) - + self.client = self.initialize_client() # Store the bank as a new collection in Weaviate self.client.collections.create( name=bank_id @@ -145,6 +149,9 @@ class WeaviateMemoryAdapter(Memory): return bank_index.bank async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: + + self.client = self.initialize_client() + if bank_id in self.cache: return self.cache[bank_id]