diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index 1349df0a3..5cfed8518 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -83,7 +83,13 @@ async def run_main(host: str, port: int, stream: bool): overlap_size_in_tokens=64, ) await banks_client.register_memory_bank( - bank.identifier, bank.memory_bank_type, provider_resource_id=bank.identifier + bank.identifier, + VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + provider_resource_id=bank.identifier, ) retrieved_bank = await banks_client.get_memory_bank(bank.identifier) diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 6a5bef325..8d137e43c 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -58,7 +58,7 @@ class MemoryBanksClient(MemoryBanks): async def register_memory_bank( self, memory_bank_id: str, - memory_bank_type: MemoryBankType, + params: BankParams, provider_resource_id: Optional[str] = None, provider_id: Optional[str] = None, ) -> None: @@ -67,9 +67,9 @@ class MemoryBanksClient(MemoryBanks): f"{self.base_url}/memory_banks/register", json={ "memory_bank_id": memory_bank_id, - "memory_bank_type": memory_bank_type.value, "provider_resource_id": provider_resource_id, "provider_id": provider_id, + "params": params.dict(), }, headers={"Content-Type": "application/json"}, )