diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 8fa82eb5d..f4ae21da4 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -94,7 +94,7 @@ class AgenticSystemClient(AgenticSystem): print(f"Error with parsing or validation: {e}") -async def _run_agent(api, tool_definitions, user_prompts): +async def _run_agent(api, tool_definitions, user_prompts, attachments=None): agent_config = AgentConfig( model="Meta-Llama3.1-8B-Instruct", instructions="You are a helpful assistant", @@ -119,6 +119,7 @@ async def _run_agent(api, tool_definitions, user_prompts): messages=[ UserMessage(content=content), ], + attachments=attachments, stream=True, ) ) @@ -168,17 +169,36 @@ async def run_main(host: str, port: int): async def run_rag(host: str, port: int): api = AgenticSystemClient(f"http://{host}:{port}") - # NOTE: for this, I ran `llama_toolchain.memory.client` first which - # populated the memory banks with torchtune docs. Then grabbed the bank_id + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + "datasets.rst", + "qat_finetune.rst", + "lora_finetune.rst", + ] + attachments = [ + Attachment( + content=URL( + uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}" + ), + mime_type="text/plain", + ) + for i, url in enumerate(urls) + ] + + # Alternatively, you can pre-populate the memory bank with documents for example, + # using `llama_toolchain.memory.client`. Then you can grab the bank_id # from the output of that run. tool_definitions = [ MemoryToolDefinition( max_tokens_in_context=2048, - memory_bank_configs=[ - AgenticSystemVectorMemoryBankConfig( - bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed", - ) - ], + memory_bank_configs=[], + # memory_bank_configs=[ + # AgenticSystemVectorMemoryBankConfig( + # bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed", + # ) + # ], ), ] @@ -187,11 +207,11 @@ async def run_rag(host: str, port: int): "Tell me briefly about llama3 and torchtune", ] - await _run_agent(api, tool_definitions, user_prompts) + await _run_agent(api, tool_definitions, user_prompts, attachments) def main(host: str, port: int): - asyncio.run(run_main(host, port)) + asyncio.run(run_rag(host, port)) if __name__ == "__main__": diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 6bd681812..127ebf18e 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -580,6 +580,7 @@ class ChatAgent(ShieldRunnerMixin): name=f"memory_bank_{session.session_id}", config=VectorMemoryBankConfig( embedding_model="sentence-transformer/all-MiniLM-L6-v2", + chunk_size_in_tokens=512, ), ) @@ -619,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin): documents = [ MemoryBankDocument( - doc_id=str(uuid.uuid4()), + document_id=str(uuid.uuid4()), content=a.content, mime_type=a.mime_type, metadata={},