Refactor Weaviate config to include cluster URL in memory adapter

This commit is contained in:
Zain Hasan 2024-09-23 21:13:24 -04:00
parent 2fc9bd95a6
commit ba8044d243
3 changed files with 11 additions and 8 deletions

View file

@ -44,7 +44,8 @@ class MemoryClient(Memory):
}, },
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
}, },
timeout=20, timeout=20,
) )
@ -70,7 +71,8 @@ class MemoryClient(Memory):
}, },
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
}, },
timeout=20, timeout=20,
) )
@ -94,7 +96,8 @@ class MemoryClient(Memory):
}, },
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
}, },
timeout=20, timeout=20,
) )
@ -116,7 +119,8 @@ class MemoryClient(Memory):
}, },
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}), "X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
}, },
timeout=20, timeout=20,
) )

View file

@ -11,8 +11,8 @@ class WeaviateRequestProviderData(BaseModel):
# if there _is_ provider data, it must specify the API KEY # if there _is_ provider data, it must specify the API KEY
# if you want it to be optional, use Optional[str] # if you want it to be optional, use Optional[str]
weaviate_api_key: str weaviate_api_key: str
weaviate_cluster_url: str
@json_schema_type @json_schema_type
class WeaviateConfig(BaseModel): class WeaviateConfig(BaseModel):
url: str = Field(default="http://localhost:8080")
collection: str = Field(default="MemoryBank") collection: str = Field(default="MemoryBank")

View file

@ -1,7 +1,6 @@
import json import json
import uuid import uuid
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from urllib.parse import urlparse
from numpy.typing import NDArray from numpy.typing import NDArray
import weaviate import weaviate
@ -72,7 +71,6 @@ class WeaviateIndex(EmbeddingIndex):
class WeaviateMemoryAdapter(Memory): class WeaviateMemoryAdapter(Memory):
def __init__(self, config: WeaviateConfig) -> None: def __init__(self, config: WeaviateConfig) -> None:
print(f"Initializing WeaviateMemoryAdapter with URL: {config.url}")
self.config = config self.config = config
self.client = None self.client = None
self.cache = {} self.cache = {}
@ -85,9 +83,10 @@ class WeaviateMemoryAdapter(Memory):
assert isinstance(request_provider_data, WeaviateRequestProviderData) assert isinstance(request_provider_data, WeaviateRequestProviderData)
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}") print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud # Connect to Weaviate Cloud
self.client = weaviate.connect_to_weaviate_cloud( self.client = weaviate.connect_to_weaviate_cloud(
cluster_url = self.config.url, cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
) )