mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Refactor WeaviateMemoryAdapter initialization and client handling
This commit is contained in:
parent
af1710af75
commit
aca4a4d7fc
1 changed files with 13 additions and 11 deletions
|
@ -75,9 +75,9 @@ class WeaviateMemoryAdapter(Memory):
|
|||
self.client = None
|
||||
self.cache = {}
|
||||
|
||||
async def initialize_client(self) -> weaviate.Client:
|
||||
try:
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
request_provider_data = get_request_provider_data()
|
||||
|
||||
if request_provider_data is not None:
|
||||
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
||||
|
||||
|
@ -85,14 +85,18 @@ class WeaviateMemoryAdapter(Memory):
|
|||
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
|
||||
|
||||
# Connect to Weaviate Cloud
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
return weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url = request_provider_data.weaviate_cluster_url,
|
||||
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.client = self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(self.config.collection):
|
||||
client.collections.create(
|
||||
if not self.client.collections.exists(self.config.collection):
|
||||
self.client.collections.create(
|
||||
name = self.config.collection,
|
||||
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
|
@ -102,8 +106,6 @@ class WeaviateMemoryAdapter(Memory):
|
|||
),
|
||||
]
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
@ -111,7 +113,7 @@ class WeaviateMemoryAdapter(Memory):
|
|||
raise RuntimeError("Could not connect to Weaviate server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client = self.initialize_client()
|
||||
self.client = self._get_client()
|
||||
|
||||
if self.client:
|
||||
self.client.close()
|
||||
|
@ -129,7 +131,8 @@ class WeaviateMemoryAdapter(Memory):
|
|||
config=config,
|
||||
url=url,
|
||||
)
|
||||
self.client = self.initialize_client()
|
||||
self.client = self._get_client()
|
||||
|
||||
# Store the bank as a new collection in Weaviate
|
||||
self.client.collections.create(
|
||||
name=bank_id
|
||||
|
@ -150,7 +153,7 @@ class WeaviateMemoryAdapter(Memory):
|
|||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
|
||||
self.client = self.initialize_client()
|
||||
self.client = self._get_client()
|
||||
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
@ -173,7 +176,6 @@ class WeaviateMemoryAdapter(Memory):
|
|||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
if not index:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue