addressed the PR comments

This commit is contained in:
Zain Hasan 2024-09-24 16:21:37 -04:00
parent ba8044d243
commit a138b48eef
3 changed files with 20 additions and 29 deletions

View file

@ -42,11 +42,7 @@ class MemoryClient(Memory):
params={ params={
"bank_id": bank_id, "bank_id": bank_id,
}, },
headers={ headers={"Content-Type": "application/json"},
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20, timeout=20,
) )
r.raise_for_status() r.raise_for_status()
@ -69,11 +65,7 @@ class MemoryClient(Memory):
"config": config.dict(), "config": config.dict(),
"url": url, "url": url,
}, },
headers={ headers={"Content-Type": "application/json"},
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20, timeout=20,
) )
r.raise_for_status() r.raise_for_status()
@ -94,11 +86,7 @@ class MemoryClient(Memory):
"bank_id": bank_id, "bank_id": bank_id,
"documents": [d.dict() for d in documents], "documents": [d.dict() for d in documents],
}, },
headers={ headers={"Content-Type": "application/json"},
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20, timeout=20,
) )
r.raise_for_status() r.raise_for_status()
@ -117,11 +105,7 @@ class MemoryClient(Memory):
"query": query, "query": query,
"params": params, "params": params,
}, },
headers={ headers={"Content-Type": "application/json"},
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
timeout=20, timeout=20,
) )
r.raise_for_status() r.raise_for_status()

View file

@ -1,8 +1,8 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig from .config import WeaviateConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps): async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config.url, config.username, config.password) impl = WeaviateMemoryAdapter(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -75,24 +75,24 @@ class WeaviateMemoryAdapter(Memory):
self.client = None self.client = None
self.cache = {} self.cache = {}
async def initialize(self) -> None: async def initialize_client(self) -> weaviate.Client:
try: try:
request_provider_data = get_request_provider_data() request_provider_data = get_request_provider_data()
if request_provider_data is not None: if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData) assert isinstance(request_provider_data, WeaviateRequestProviderData)
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}") print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}") print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud # Connect to Weaviate Cloud
self.client = weaviate.connect_to_weaviate_cloud( client = weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url, cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
) )
# Create collection if it doesn't exist # Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection): if not client.collections.exists(self.config.collection):
self.client.collections.create( client.collections.create(
name = self.config.collection, name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(), vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[ properties=[
@ -102,6 +102,8 @@ class WeaviateMemoryAdapter(Memory):
), ),
] ]
) )
return client
except Exception as e: except Exception as e:
import traceback import traceback
@ -109,6 +111,8 @@ class WeaviateMemoryAdapter(Memory):
raise RuntimeError("Could not connect to Weaviate server") from e raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client = self.initialize_client()
if self.client: if self.client:
self.client.close() self.client.close()
@ -125,7 +129,7 @@ class WeaviateMemoryAdapter(Memory):
config=config, config=config,
url=url, url=url,
) )
self.client = self.initialize_client()
# Store the bank as a new collection in Weaviate # Store the bank as a new collection in Weaviate
self.client.collections.create( self.client.collections.create(
name=bank_id name=bank_id
@ -145,6 +149,9 @@ class WeaviateMemoryAdapter(Memory):
return bank_index.bank return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self.initialize_client()
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]