# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

# import json
from typing import Any, Dict, List, Optional

import fire
import httpx

# from termcolor import cprint

from .api import *  # noqa: F403


def _to_jsonable(value: Any) -> Any:
    """Recursively convert a request value into JSON-encodable builtins.

    Objects exposing a callable ``.dict()`` (the model API this file already
    relies on via ``config.dict()``) are converted with it. Lists, tuples and
    dicts are walked so that models nested inside them are converted too.
    Anything else is returned unchanged and left for the JSON encoder to
    accept or reject.
    """
    to_dict = getattr(value, "dict", None)
    if callable(to_dict):
        return _to_jsonable(to_dict())
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    return value


async def get_client_impl(base_url: str):
    """Return a :class:`MemoryClient` for the memory server at *base_url*."""
    return MemoryClient(base_url)


class MemoryClient(Memory):
    """HTTP client implementation of the ``Memory`` API.

    Every call opens a short-lived ``httpx.AsyncClient``, sends a single
    request to ``base_url`` and calls ``raise_for_status()`` on the response,
    so a non-2xx reply surfaces as ``httpx.HTTPStatusError``.
    """

    def __init__(self, base_url: str, timeout: float = 20):
        print(f"Initializing client for {base_url}")
        self.base_url = base_url
        # Per-request timeout (seconds) handed to httpx; defaults to the value
        # that used to be hard-coded at every call site.
        self.timeout = timeout

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def _get(self, path: str, params: Dict[str, Any]) -> httpx.Response:
        """GET ``base_url + path`` with query *params*; raise on an error status."""
        async with httpx.AsyncClient() as client:
            # AsyncClient.get() returns an awaitable Response, not an async
            # context manager, so it has to be awaited rather than used with
            # ``async with``.
            r = await client.get(
                f"{self.base_url}{path}",
                params=params,
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )
            r.raise_for_status()
            return r

    async def _post(self, path: str, payload: Dict[str, Any]) -> httpx.Response:
        """POST *payload* as a JSON body to ``base_url + path``; raise on an error status."""
        async with httpx.AsyncClient() as client:
            # json= (not data=) so the body actually is JSON, as the
            # Content-Type header claims; data= would form-encode the dict
            # and flatten nested values/models into their str() repr.
            r = await client.post(
                f"{self.base_url}{path}",
                json=_to_jsonable(payload),
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )
            r.raise_for_status()
            return r

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Fetch the bank identified by *bank_id*.

        Returns ``None`` when the server replies with an empty or null body.
        """
        r = await self._get("/memory_banks/get", {"bank_id": bank_id})
        d = r.json()
        # `not d` also covers a JSON null, on which len() would raise.
        if not d:
            return None
        return MemoryBank(**d)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> Optional[MemoryBank]:
        """Create a memory bank named *name* with *config* (and optional *url*).

        Returns the created bank, or ``None`` if the server replies with an
        empty or null body.
        """
        r = await self._post(
            "/memory_banks/create",
            {
                "name": name,
                "config": config.dict(),
                # Converted via .dict() by _to_jsonable if URL is a model.
                "url": url,
            },
        )
        d = r.json()
        if not d:
            return None
        return MemoryBank(**d)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        """Insert *documents* into the bank identified by *bank_id*."""
        # NOTE(review): bank routes use "memory_banks/" while document routes
        # use "memory_bank/"; confirm both match the server's route table.
        await self._post(
            "/memory_bank/insert",
            {
                "bank_id": bank_id,
                # Each document model is turned into a dict by _to_jsonable.
                "documents": documents,
            },
        )

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Query the bank identified by *bank_id* with *query* and optional *params*."""
        r = await self._post(
            "/memory_bank/query",
            {
                "bank_id": bank_id,
                "query": query,
                "params": params,
            },
        )
        return QueryDocumentsResponse(**r.json())


async def run_main(host: str, port: int, stream: bool):
    """Smoke-test the client: create a bank, read it back, insert, then query.

    *stream* is accepted for CLI compatibility but is currently unused.
    """
    client = MemoryClient(f"http://{host}:{port}")

    # create a memory bank
    bank = await client.create_memory_bank(
        name="test_bank",
        config=VectorMemoryBankConfig(
            bank_id="test_bank",
            embedding_model="dragon-roberta-query-2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    print(bank)

    retrieved_bank = await client.get_memory_bank(bank.bank_id)
    assert retrieved_bank is not None
    # NOTE(review): embedding_model was supplied inside `config`; confirm that
    # MemoryBank exposes it at top level, otherwise this should read
    # `retrieved_bank.config.embedding_model`.
    assert retrieved_bank.embedding_model == "dragon-roberta-query-2"

    # insert some documents
    await client.insert_documents(
        bank_id=bank.bank_id,
        documents=[
            MemoryBankDocument(
                document_id="1",
                content="hello world",
            ),
            MemoryBankDocument(
                document_id="2",
                content="goodbye world",
            ),
        ],
    )

    # query the documents
    response = await client.query_documents(
        bank_id=bank.bank_id,
        query=[
            "hello world",
        ],
    )
    print(response)


def main(host: str, port: int, stream: bool = True):
    """CLI entry point: run the smoke test against ``http://host:port``."""
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)