migrate memory banks to Resource and new registration

This commit is contained in:
Dinesh Yeduguru 2024-11-08 15:45:26 -08:00
parent b4416b72fd
commit c82f13bf9e
16 changed files with 178 additions and 104 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
from typing import Any, Dict, List, Optional
@ -26,13 +25,13 @@ def deserialize_memory_bank_def(
raise ValueError("Memory bank type not specified")
type = j["type"]
if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j)
return VectorMemoryBank(**j)
elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j)
return KeyValueMemoryBank(**j)
elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j)
return KeywordMemoryBank(**j)
elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j)
return GraphMemoryBank(**j)
else:
raise ValueError(f"Unknown memory bank type: {type}")
@ -47,7 +46,7 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None:
pass
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
async def list_memory_banks(self) -> List[MemoryBank]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
@ -57,13 +56,20 @@ class MemoryBanksClient(MemoryBanks):
return [deserialize_memory_bank_def(x) for x in response.json()]
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
self,
memory_bank_id: str,
memory_bank_type: MemoryBankType,
provider_resource_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/memory_banks/register",
json={
"memory_bank": json.loads(memory_bank.json()),
"memory_bank_id": memory_bank_id,
"memory_bank_type": memory_bank_type.value,
"provider_resource_id": provider_resource_id,
"provider_id": provider_id,
},
headers={"Content-Type": "application/json"},
)
@ -71,13 +77,13 @@ class MemoryBanksClient(MemoryBanks):
async def get_memory_bank(
self,
identifier: str,
) -> Optional[MemoryBankDefWithProvider]:
memory_bank_id: str,
) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"identifier": identifier,
"memory_bank_id": memory_bank_id,
},
headers={"Content-Type": "application/json"},
)
@ -94,7 +100,7 @@ async def run_main(host: str, port: int, stream: bool):
# register memory bank for the first time
response = await client.register_memory_bank(
VectorMemoryBankDef(
VectorMemoryBank(
identifier="test_bank2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,