Refactor Weaviate config to include cluster URL in memory adapter

This commit is contained in:
Zain Hasan 2024-09-23 21:13:24 -04:00
parent 2fc9bd95a6
commit ba8044d243
3 changed files with 11 additions and 8 deletions

View file

@ -44,7 +44,8 @@ class MemoryClient(Memory):
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20,
)
@ -70,7 +71,8 @@ class MemoryClient(Memory):
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20,
)
@ -94,7 +96,8 @@ class MemoryClient(Memory):
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20,
)
@ -116,7 +119,8 @@ class MemoryClient(Memory):
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20,
)

View file

@ -11,8 +11,8 @@ class WeaviateRequestProviderData(BaseModel):
# if there _is_ provider data, it must specify the API KEY
# if you want it to be optional, use Optional[str]
weaviate_api_key: str
weaviate_cluster_url: str
@json_schema_type
class WeaviateConfig(BaseModel):
url: str = Field(default="http://localhost:8080")
collection: str = Field(default="MemoryBank")

View file

@ -1,7 +1,6 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from urllib.parse import urlparse
from numpy.typing import NDArray
import weaviate
@ -72,7 +71,6 @@ class WeaviateIndex(EmbeddingIndex):
class WeaviateMemoryAdapter(Memory):
def __init__(self, config: WeaviateConfig) -> None:
print(f"Initializing WeaviateMemoryAdapter with URL: {config.url}")
self.config = config
self.client = None
self.cache = {}
@ -85,9 +83,10 @@ class WeaviateMemoryAdapter(Memory):
assert isinstance(request_provider_data, WeaviateRequestProviderData)
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url = self.config.url,
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)