addressed the PR comments

This commit is contained in:
Zain Hasan 2024-09-24 16:21:37 -04:00
parent ba8044d243
commit a138b48eef
3 changed files with 20 additions and 29 deletions

View file

@ -42,11 +42,7 @@ class MemoryClient(Memory):
params={
"bank_id": bank_id,
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
@ -69,11 +65,7 @@ class MemoryClient(Memory):
"config": config.dict(),
"url": url,
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
@ -94,11 +86,7 @@ class MemoryClient(Memory):
"bank_id": bank_id,
"documents": [d.dict() for d in documents],
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
@ -117,11 +105,7 @@ class MemoryClient(Memory):
"query": query,
"params": params,
},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
"weaviate_cluster_url": "http://localhost:8080"}),
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()

View file

@ -1,8 +1,8 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .config import WeaviateConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
async def get_adapter_impl(config: WeaviateConfig, _deps):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config.url, config.username, config.password)
impl = WeaviateMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -75,24 +75,24 @@ class WeaviateMemoryAdapter(Memory):
self.client = None
self.cache = {}
async def initialize(self) -> None:
async def initialize_client(self) -> weaviate.Client:
try:
request_provider_data = get_request_provider_data()
if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData)
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
# Connect to Weaviate Cloud
self.client = weaviate.connect_to_weaviate_cloud(
client = weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)
# Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
if not client.collections.exists(self.config.collection):
client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
@ -102,6 +102,8 @@ class WeaviateMemoryAdapter(Memory):
),
]
)
return client
except Exception as e:
import traceback
@ -109,6 +111,8 @@ class WeaviateMemoryAdapter(Memory):
raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None:
self.client = self.initialize_client()
if self.client:
self.client.close()
@ -125,7 +129,7 @@ class WeaviateMemoryAdapter(Memory):
config=config,
url=url,
)
self.client = self.initialize_client()
# Store the bank as a new collection in Weaviate
self.client.collections.create(
name=bank_id
@ -145,6 +149,9 @@ class WeaviateMemoryAdapter(Memory):
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self.initialize_client()
if bank_id in self.cache:
return self.cache[bank_id]