mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
addressed the PR comments
This commit is contained in:
parent
ba8044d243
commit
a138b48eef
3 changed files with 20 additions and 29 deletions
|
@ -42,11 +42,7 @@ class MemoryClient(Memory):
|
||||||
params={
|
params={
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
},
|
},
|
||||||
headers={
|
headers={"Content-Type": "application/json"},
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
|
|
||||||
"weaviate_cluster_url": "http://localhost:8080"}),
|
|
||||||
},
|
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
@ -69,11 +65,7 @@ class MemoryClient(Memory):
|
||||||
"config": config.dict(),
|
"config": config.dict(),
|
||||||
"url": url,
|
"url": url,
|
||||||
},
|
},
|
||||||
headers={
|
headers={"Content-Type": "application/json"},
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
|
|
||||||
"weaviate_cluster_url": "http://localhost:8080"}),
|
|
||||||
},
|
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
@ -94,11 +86,7 @@ class MemoryClient(Memory):
|
||||||
"bank_id": bank_id,
|
"bank_id": bank_id,
|
||||||
"documents": [d.dict() for d in documents],
|
"documents": [d.dict() for d in documents],
|
||||||
},
|
},
|
||||||
headers={
|
headers={"Content-Type": "application/json"},
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
|
|
||||||
"weaviate_cluster_url": "http://localhost:8080"}),
|
|
||||||
},
|
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
@ -117,11 +105,7 @@ class MemoryClient(Memory):
|
||||||
"query": query,
|
"query": query,
|
||||||
"params": params,
|
"params": params,
|
||||||
},
|
},
|
||||||
headers={
|
headers={"Content-Type": "application/json"},
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234",
|
|
||||||
"weaviate_cluster_url": "http://localhost:8080"}),
|
|
||||||
},
|
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from .config import WeaviateConfig
|
||||||
|
|
||||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||||
from .weaviate import WeaviateMemoryAdapter
|
from .weaviate import WeaviateMemoryAdapter
|
||||||
|
|
||||||
impl = WeaviateMemoryAdapter(config.url, config.username, config.password)
|
impl = WeaviateMemoryAdapter(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
|
@ -75,24 +75,24 @@ class WeaviateMemoryAdapter(Memory):
|
||||||
self.client = None
|
self.client = None
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize_client(self) -> weaviate.Client:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
request_provider_data = get_request_provider_data()
|
request_provider_data = get_request_provider_data()
|
||||||
if request_provider_data is not None:
|
if request_provider_data is not None:
|
||||||
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
||||||
|
|
||||||
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
|
print(f"WEAVIATE API KEY: {request_provider_data.weaviate_api_key}")
|
||||||
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
|
print(f"WEAVIATE CLUSTER URL: {request_provider_data.weaviate_cluster_url}")
|
||||||
|
|
||||||
# Connect to Weaviate Cloud
|
# Connect to Weaviate Cloud
|
||||||
self.client = weaviate.connect_to_weaviate_cloud(
|
client = weaviate.connect_to_weaviate_cloud(
|
||||||
cluster_url = request_provider_data.weaviate_cluster_url,
|
cluster_url = request_provider_data.weaviate_cluster_url,
|
||||||
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create collection if it doesn't exist
|
# Create collection if it doesn't exist
|
||||||
if not self.client.collections.exists(self.config.collection):
|
if not client.collections.exists(self.config.collection):
|
||||||
self.client.collections.create(
|
client.collections.create(
|
||||||
name = self.config.collection,
|
name = self.config.collection,
|
||||||
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
|
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
|
||||||
properties=[
|
properties=[
|
||||||
|
@ -103,12 +103,16 @@ class WeaviateMemoryAdapter(Memory):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise RuntimeError("Could not connect to Weaviate server") from e
|
raise RuntimeError("Could not connect to Weaviate server") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
self.client = self.initialize_client()
|
||||||
|
|
||||||
if self.client:
|
if self.client:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
|
@ -125,7 +129,7 @@ class WeaviateMemoryAdapter(Memory):
|
||||||
config=config,
|
config=config,
|
||||||
url=url,
|
url=url,
|
||||||
)
|
)
|
||||||
|
self.client = self.initialize_client()
|
||||||
# Store the bank as a new collection in Weaviate
|
# Store the bank as a new collection in Weaviate
|
||||||
self.client.collections.create(
|
self.client.collections.create(
|
||||||
name=bank_id
|
name=bank_id
|
||||||
|
@ -145,6 +149,9 @@ class WeaviateMemoryAdapter(Memory):
|
||||||
return bank_index.bank
|
return bank_index.bank
|
||||||
|
|
||||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
|
|
||||||
|
self.client = self.initialize_client()
|
||||||
|
|
||||||
if bank_id in self.cache:
|
if bank_id in self.cache:
|
||||||
return self.cache[bank_id]
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue