Make Safety test work, other cleanup

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:09:50 -07:00
parent ba1f294cc6
commit fcd22b6baa
16 changed files with 229 additions and 123 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import json
from typing import Any, Dict, List, Optional
@ -15,7 +16,9 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(j: Optional[Dict[str, Any]]) -> MemoryBankDef:
def deserialize_memory_bank_def(
j: Optional[Dict[str, Any]]
) -> MemoryBankDefWithProvider:
if j is None:
return None
@ -44,7 +47,7 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_memory_banks(self) -> List[MemoryBankDef]:
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
@ -53,10 +56,23 @@ class MemoryBanksClient(MemoryBanks):
response.raise_for_status()
return [deserialize_memory_bank_def(x) for x in response.json()]
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/memory_banks/register",
json={
"memory_bank": json.loads(memory_bank.json()),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def get_memory_bank(
self,
identifier: str,
) -> Optional[MemoryBankDef]:
) -> Optional[MemoryBankDefWithProvider]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",