memory bank registration fixes

This commit is contained in:
Ashwin Bharambe 2024-10-06 22:00:54 -07:00 committed by Ashwin Bharambe
parent 099a95b614
commit 3725e74906
8 changed files with 108 additions and 62 deletions

View file

@ -5,8 +5,9 @@
# the root directory of this source tree.
import asyncio
import json
from typing import List, Optional
from typing import Any, Dict, List, Optional
import fire
import httpx
@ -15,6 +16,25 @@ from termcolor import cprint
from .memory_banks import * # noqa: F403
def deserialize_memory_bank_def(
    j: Optional[Dict[str, Any]],
) -> "Optional[MemoryBankDef]":
    """Build the concrete memory bank definition described by a JSON object.

    Args:
        j: Decoded JSON object for a memory bank, or ``None``.

    Returns:
        ``None`` when ``j`` is ``None``; otherwise an instance of the
        definition class selected by ``j["type"]``, constructed from all
        keys of ``j``.

    Raises:
        ValueError: If ``j`` has no ``"type"`` key, or its value matches
            none of the vector / keyvalue / keyword / graph bank types.
    """
    if j is None:
        return None
    if "type" not in j:
        raise ValueError("Memory bank type not specified")

    # Named ``bank_type`` rather than ``type`` so the builtin is not shadowed.
    bank_type = j["type"]
    if bank_type == MemoryBankType.vector.value:
        return VectorMemoryBankDef(**j)
    if bank_type == MemoryBankType.keyvalue.value:
        return KeyValueMemoryBankDef(**j)
    if bank_type == MemoryBankType.keyword.value:
        return KeywordMemoryBankDef(**j)
    if bank_type == MemoryBankType.graph.value:
        return GraphMemoryBankDef(**j)
    raise ValueError(f"Unknown memory bank type: {bank_type}")
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
@ -25,37 +45,57 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
    """Do nothing; this method holds no resources to release."""
async def list_memory_banks(self) -> List[MemoryBankDef]:
    """Fetch the memory bank definitions returned by the server.

    Issues ``GET {base_url}/memory_banks/list`` and converts each element
    of the JSON array with ``deserialize_memory_bank_def``.

    Returns:
        One deserialized definition per element of the response array.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        ValueError: If an element has a missing or unknown ``"type"``.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(
            f"{self.base_url}/memory_banks/list",
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        return [deserialize_memory_bank_def(x) for x in response.json()]
async def get_memory_bank(
    self,
    identifier: str,
) -> Optional[MemoryBankDef]:
    """Look up a single memory bank definition by its identifier.

    Issues ``GET {base_url}/memory_banks/get`` with ``identifier`` as a
    query parameter.

    Args:
        identifier: Identifier of the memory bank to fetch.

    Returns:
        The deserialized definition, or ``None`` when the response body is
        JSON ``null``.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        ValueError: If the payload has a missing or unknown ``"type"``.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(
            f"{self.base_url}/memory_banks/get",
            params={
                "identifier": identifier,
            },
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        j = response.json()
        if j is None:
            return None
        return deserialize_memory_bank_def(j)
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
    """Send a memory bank definition to the server for registration.

    The definition is serialized with its own ``json()`` method and parsed
    back into plain JSON-compatible data so it can be nested under the
    ``"memory_bank"`` key of the request body.

    Args:
        memory_bank: Definition of the bank to register.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    # NOTE(review): this route uses the ``/memory/`` prefix while the list
    # and get calls use ``/memory_banks/`` -- confirm against the server.
    endpoint = f"{self.base_url}/memory/register_memory_bank"
    async with httpx.AsyncClient() as http:
        payload = {"memory_bank": json.loads(memory_bank.json())}
        result = await http.post(
            endpoint,
            json=payload,
            headers={"Content-Type": "application/json"},
        )
        result.raise_for_status()
async def run_main(host: str, port: int, stream: bool):
    """Register a sample vector memory bank, then list and print all banks.

    Args:
        host: Hostname of the server to connect to.
        port: TCP port of the server.
        stream: Not used by this function.
    """
    client = MemoryBanksClient(f"http://{host}:{port}")

    await client.register_memory_bank(
        VectorMemoryBankDef(
            identifier="test_bank",
            provider_id="",
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )

    response = await client.list_memory_banks()
    cprint(f"list_memory_banks response={response}", "green")