# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from typing import Any, Dict, List, Optional

import fire
import httpx

from llama_toolchain.core.datatypes import RemoteProviderConfig

from .api import *  # noqa: F403


# Every request sent by MemoryClient uses these headers and this timeout.
_JSON_HEADERS = {"Content-Type": "application/json"}
_REQUEST_TIMEOUT_SECONDS = 20


def _encode(value: Any) -> Any:
    """Return *value* in a form that `json=` can serialize.

    Lists and tuples are encoded element by element. Any object exposing a
    callable ``dict`` attribute is replaced by the result of calling it, the
    same conversion this module already applies to `config` and `documents`.
    Everything else (``None``, ``str``, numbers, plain dicts) is returned
    unchanged.
    """
    if isinstance(value, (list, tuple)):
        return [_encode(item) for item in value]
    to_dict = getattr(value, "dict", None)
    if callable(to_dict):
        return to_dict()
    return value


async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
    """Build a `MemoryClient` that talks to `config.url`; `_deps` is unused."""
    return MemoryClient(config.url)


class MemoryClient(Memory):
    """`Memory` implementation that forwards each call over HTTP to `base_url`."""

    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; a fresh HTTP client is opened for every request."""
        pass

    async def shutdown(self) -> None:
        """No-op; nothing is kept open between requests."""
        pass

    async def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
        """Send one request to `base_url + path` and return the response.

        Extra keyword arguments (`params`, `json`) are passed through to httpx.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.request(
                method,
                f"{self.base_url}{path}",
                headers=_JSON_HEADERS,
                timeout=_REQUEST_TIMEOUT_SECONDS,
                **kwargs,
            )
            response.raise_for_status()
            return response

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Fetch the bank with `bank_id`; return None when the reply body is empty."""
        response = await self._request(
            "GET",
            "/memory_banks/get",
            params={"bank_id": bank_id},
        )
        payload = response.json()
        if not payload:
            return None
        return MemoryBank(**payload)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Create a memory bank on the server and return it.

        NOTE(review): an empty reply body yields None even though the return
        annotation is `MemoryBank`; kept as-is so existing callers see no
        change in behavior.
        """
        response = await self._request(
            "POST",
            "/memory_banks/create",
            json={
                "name": name,
                "config": config.dict(),
                # The raw URL object is not JSON-serializable; encode it the
                # same way as `config`. None passes through unchanged.
                "url": _encode(url),
            },
        )
        payload = response.json()
        if not payload:
            return None
        return MemoryBank(**payload)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        """Send `documents` to the server for insertion into bank `bank_id`."""
        await self._request(
            "POST",
            "/memory_bank/insert",
            json={
                "bank_id": bank_id,
                "documents": [doc.dict() for doc in documents],
            },
        )

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Run `query` against bank `bank_id` and return the server's response."""
        response = await self._request(
            "POST",
            "/memory_bank/query",
            json={
                "bank_id": bank_id,
                # Plain strings / lists of strings are sent unchanged.
                # NOTE(review): presumably `query` may also carry non-string
                # media objects; those are converted via their `dict()`.
                "query": _encode(query),
                "params": params,
            },
        )
        return QueryDocumentsResponse(**response.json())


def _print_query_results(response: QueryDocumentsResponse) -> None:
    """Print each returned chunk together with its score."""
    for chunk, score in zip(response.chunks, response.scores):
        print(f"Score: {score}")
        print(f"Chunk:\n========\n{chunk}\n========\n")


async def run_main(host: str, port: int, stream: bool):
    """Smoke-test a memory server: create a bank, insert documents, query it.

    NOTE(review): `stream` is accepted but never used in this function.
    """
    client = MemoryClient(f"http://{host}:{port}")

    # create a memory bank
    bank = await client.create_memory_bank(
        name="test_bank",
        config=VectorMemoryBankConfig(
            bank_id="test_bank",
            embedding_model="dragon-roberta-query-2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    print(bank)

    retrieved_bank = await client.get_memory_bank(bank.bank_id)
    assert retrieved_bank is not None
    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"

    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    documents = [
        MemoryBankDocument(
            document_id=f"num-{i}",
            content=URL(
                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
            ),
            mime_type="text/plain",
        )
        for i, url in enumerate(urls)
    ]

    # insert some documents
    await client.insert_documents(
        bank_id=bank.bank_id,
        documents=documents,
    )

    # query the documents
    for question in (
        "How do I use Lora?",
        "Tell me more about llama3 and torchtune",
    ):
        response = await client.query_documents(
            bank_id=bank.bank_id,
            query=[question],
        )
        _print_query_results(response)


def main(host: str, port: int, stream: bool = True):
    """CLI entry point: run the smoke test against `host:port`."""
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)