From 58e2feceb0dad42e4d9f7e3fe050ff657dcdf8e8 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sat, 24 Aug 2024 23:36:58 -0700
Subject: [PATCH] basic RAG seems to work

---
 llama_toolchain/agentic_system/client.py      | 97 ++++++++++++-------
 .../meta_reference/agent_instance.py          | 22 +++--
 llama_toolchain/inference/prepare_messages.py | 21 +++-
 3 files changed, 96 insertions(+), 44 deletions(-)

diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index e2adac495..8fa82eb5d 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -82,6 +82,10 @@ class AgenticSystemClient(AgenticSystem):
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
+                        if "error" in data:
+                            cprint(data, "red")
+                            continue
+
                         yield AgenticSystemTurnResponseStreamChunk(
                             **json.loads(data)
                         )
@@ -90,8 +94,41 @@ class AgenticSystemClient(AgenticSystem):
                         print(f"Error with parsing or validation: {e}")
 
 
+async def _run_agent(api, tool_definitions, user_prompts):
+    agent_config = AgentConfig(
+        model="Meta-Llama3.1-8B-Instruct",
+        instructions="You are a helpful assistant",
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        tools=tool_definitions,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
+    )
+
+    create_response = await api.create_agentic_system(agent_config)
+    session_response = await api.create_agentic_system_session(
+        agent_id=create_response.agent_id,
+        session_name="test_session",
+    )
+
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                agent_id=create_response.agent_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
+
+
 async def run_main(host: str, port: int):
-    # client to test remote impl of agentic system
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
@@ -118,24 +155,6 @@ async def run_main(host: str, port: int):
         ),
     ]
 
-    agent_config = AgentConfig(
-        model="Meta-Llama3.1-8B-Instruct",
-        instructions="You are a helpful assistant",
-        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
-        tools=tool_definitions,
-        tool_choice=ToolChoice.auto,
-        tool_prompt_format=ToolPromptFormat.function_tag,
-    )
-
-    create_response = await api.create_agentic_system(agent_config)
-    print(create_response)
-
-    session_response = await api.create_agentic_system_session(
-        agent_id=create_response.agent_id,
-        session_name="test_session",
-    )
-    print(session_response)
-
     user_prompts = [
         "Who are you?",
         "what is the 100th prime number?",
@@ -143,22 +162,32 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
 
-    for content in user_prompts:
-        cprint(f"User> {content}", color="blue")
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
-                agent_id=create_response.agent_id,
-                session_id=session_response.session_id,
-                messages=[
-                    UserMessage(content=content),
-                ],
-                stream=True,
-            )
-        )
-
-        async for event, log in EventLogger().log(iterator):
-            if log is not None:
-                log.print()
+    await _run_agent(api, tool_definitions, user_prompts)
+
+
+async def run_rag(host: str, port: int):
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+    # NOTE: for this, I ran `llama_toolchain.memory.client` first, which
+    # populated the memory banks with torchtune docs, then grabbed the bank_id
+    # from the output of that run.
+    tool_definitions = [
+        MemoryToolDefinition(
+            max_tokens_in_context=2048,
+            memory_bank_configs=[
+                AgenticSystemVectorMemoryBankConfig(
+                    bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
+                )
+            ],
+        ),
+    ]
+
+    user_prompts = [
+        "How do I use Lora?",
+        "Tell me briefly about llama3 and torchtune",
+    ]
+
+    await _run_agent(api, tool_definitions, user_prompts)
 
 
 def main(host: str, port: int):
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 427f96be8..6bd681812 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-
+import asyncio
 import copy
 import uuid
 from datetime import datetime
@@ -304,7 +304,9 @@ class ChatAgent(ShieldRunnerMixin):
 
             # TODO: find older context from the session and either replace it
            # or append with a sliding window. this is really a very simplistic implementation
-            rag_context, bank_ids = await self._retrieve_context(input_messages)
+            rag_context, bank_ids = await self._retrieve_context(
+                session, input_messages, attachments
+            )
 
             step_id = str(uuid.uuid4())
             yield AgenticSystemTurnResponseStreamChunk(
@@ -313,20 +315,24 @@ class ChatAgent(ShieldRunnerMixin):
                         step_type=StepType.memory_retrieval.value,
                         step_id=step_id,
                         step_details=MemoryRetrievalStep(
+                            turn_id=turn_id,
+                            step_id=step_id,
                             memory_bank_ids=bank_ids,
-                            inserted_context=rag_context,
+                            inserted_context=rag_context or "",
                         ),
                     )
                 )
             )
 
         if rag_context:
-            system_message = next(m for m in input_messages if m.role == "system")
+            system_message = next(
+                (m for m in input_messages if m.role == "system"), None
+            )
             if system_message:
                 system_message.content = system_message.content + "\n" + rag_context
             else:
                 input_messages = [
-                    Message(role="system", content=rag_context)
+                    SystemMessage(content=rag_context)
                 ] + input_messages
 
         elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
@@ -644,7 +650,7 @@ class ChatAgent(ShieldRunnerMixin):
             *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
         )
         if not chunks:
-            return None
+            return None, bank_ids
 
         tokens = 0
         picked = []
@@ -656,13 +662,13 @@ class ChatAgent(ShieldRunnerMixin):
                     "red",
                 )
                 break
-            picked.append(c)
+            picked.append(c.content)
 
         return [
             "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
             *picked,
             "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ]
+        ], bank_ids
 
     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
diff --git a/llama_toolchain/inference/prepare_messages.py b/llama_toolchain/inference/prepare_messages.py
index d5ce648e1..92e94f8d2 100644
--- a/llama_toolchain/inference/prepare_messages.py
+++ b/llama_toolchain/inference/prepare_messages.py
@@ -1,4 +1,8 @@
-import textwrap
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
@@ -41,8 +45,21 @@ def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
         sys_content += default_template.render()
 
     if existing_system_message:
+        # TODO: this fn is needed in many places
+        def _process(c):
+            if isinstance(c, str):
+                return c
+            else:
+                return ""
+
         sys_content += "\n"
-        sys_content += existing_system_message.content
+
+        if isinstance(existing_system_message.content, str):
+            sys_content += _process(existing_system_message.content)
+        elif isinstance(existing_system_message.content, list):
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )
 
     messages.append(SystemMessage(content=sys_content))
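
Note for reviewers (illustrative only, not part of the patch): the reworked
`_retrieve_context` amounts to ranking retrieved chunks by score, greedily
packing them into the memory tool's token budget, and wrapping whatever fits
in START/END delimiters. Below is a minimal standalone sketch of that packing
step. The plain-string chunks and the whitespace-split token estimate are
stand-ins for the real Chunk type and tokenizer in llama_toolchain, and
`pack_context` itself is a hypothetical name:

    from typing import List, Optional

    def pack_context(
        chunks: List[str],        # chunk contents, i.e. c.content in the patch
        scores: List[float],      # similarity scores, parallel to chunks
        max_tokens: int = 2048,   # mirrors MemoryToolDefinition.max_tokens_in_context
    ) -> Optional[List[str]]:
        # Highest-scoring chunks first, mirroring sorted(zip(chunks, scores), reverse=True).
        ranked = [c for c, _ in sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)]
        tokens = 0
        picked: List[str] = []
        for content in ranked:
            n = len(content.split())  # crude token estimate; stand-in for the real tokenizer
            if tokens + n > max_tokens:
                break  # budget exhausted; lower-scored chunks are dropped
            tokens += n
            picked.append(content)
        if not picked:
            return None  # the patch returns (None, bank_ids) in this case
        return [
            "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
            *picked,
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ]

    # With a 4-token budget only the higher-scored chunk fits:
    print(pack_context(["alpha beta gamma", "delta epsilon"], [0.2, 0.9], max_tokens=4))

The same shape explains the `return None, bank_ids` change in the patch:
callers always unpack a (context, bank_ids) tuple, even when nothing fits
the budget.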