From f5620c09ad6b29e136e1e1bfb81967dcb28bb701 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 27 Aug 2024 20:09:33 -0700 Subject: [PATCH] Rag Updates --- .../meta_reference/agent_instance.py | 84 ++++++++++--------- llama_toolchain/distribution/registry.py | 2 +- llama_toolchain/memory/api/api.py | 1 + .../memory/meta_reference/faiss/faiss.py | 1 + 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 04560ff7b..1c75f0b83 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -8,7 +8,7 @@ import asyncio import copy import uuid from datetime import datetime -from typing import AsyncGenerator, List +from typing import AsyncGenerator, List, Tuple from termcolor import cprint @@ -53,6 +53,44 @@ class ChatAgent(ShieldRunnerMixin): output_shields=agent_config.output_shields, ) + def turn_to_messages(self, turn: Turn) -> List[Message]: + messages = [] + + # We do not want to keep adding RAG context to the input messages + # Maybe this should be a parameter of the agentic instance + # that can define its behavior in a custom way + for m in turn.input_messages: + msg = m.copy() + if isinstance(msg, UserMessage): + msg.context = None + messages.append(msg) + + # messages.extend(turn.input_messages) + for step in turn.steps: + if step.step_type == StepType.inference.value: + messages.append(step.model_response) + elif step.step_type == StepType.tool_execution.value: + for response in step.tool_responses: + messages.append( + ToolResponseMessage( + call_id=response.call_id, + tool_name=response.tool_name, + content=response.content, + ) + ) + elif step.step_type == StepType.shield_call.value: + response = step.response + if response.is_violation: + # TODO: Properly persist the CompletionMessage itself in the ShieldResponse + messages.append( + 
CompletionMessage( + content=response.violation_return_message, + stop_reason=StopReason.end_of_turn, + ) + ) + # print_dialog(messages) + return messages + def create_session(self, name: str) -> Session: session_id = str(uuid.uuid4()) session = Session( @@ -75,32 +113,7 @@ class ChatAgent(ShieldRunnerMixin): messages = [] for i, turn in enumerate(session.turns): - # print(f"turn {i}") - # print_dialog(turn.input_messages) - messages.extend(turn.input_messages) - for step in turn.steps: - if step.step_type == StepType.inference.value: - messages.append(step.model_response) - elif step.step_type == StepType.tool_execution.value: - for response in step.tool_responses: - messages.append( - ToolResponseMessage( - call_id=response.call_id, - tool_name=response.tool_name, - content=response.content, - ) - ) - elif step.step_type == StepType.shield_call.value: - response = step.response - if response.is_violation: - # TODO: Properly persist the - # CompletionMessage itself in the ShieldResponse - messages.append( - CompletionMessage( - content=response.violation_return_message, - stop_reason=StopReason.end_of_turn, - ) - ) + messages.extend(self.turn_to_messages(turn)) messages.extend(request.messages) @@ -325,15 +338,8 @@ class ChatAgent(ShieldRunnerMixin): ) if rag_context: - system_message = next( - (m for m in input_messages if m.role == "system"), None - ) - if system_message: - system_message.content = system_message.content + "\n" + rag_context - else: - input_messages = [ - SystemMessage(content=rag_context) - ] + input_messages + last_message = input_messages[-1] + last_message.context = "\n".join(rag_context) elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools: urls = [a.content for a in attachments if isinstance(a.content, URL)] @@ -610,7 +616,7 @@ class ChatAgent(ShieldRunnerMixin): async def _retrieve_context( self, session: Session, messages: List[Message], attachments: List[Attachment] - ) -> Optional[InterleavedTextMedia]: + ) -> 
Tuple[List[str], List[int]]: # (rag_context, bank_ids) bank_ids = [] memory = self._memory_tool_definition() @@ -671,10 +677,10 @@ class ChatAgent(ShieldRunnerMixin): "red", ) break - picked.append(c.content) + picked.append(f"id:{c.document_id}; content:{c.content}") return [ - "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n", + "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", *picked, "\n=== END-RETRIEVED-CONTEXT ===\n", ], bank_ids diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index fd9bc6e1b..bce702472 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -49,7 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]: Api.inference: providers[Api.inference]["meta-ollama"], Api.safety: providers[Api.safety]["meta-reference"], Api.agentic_system: providers[Api.agentic_system]["meta-reference"], - Api.memory: remote_spec(Api.memory), + Api.memory: providers[Api.memory]["meta-reference-faiss"], }, ), DistributionSpec( diff --git a/llama_toolchain/memory/api/api.py b/llama_toolchain/memory/api/api.py index 57370bdf3..70c7aa7ec 100644 --- a/llama_toolchain/memory/api/api.py +++ b/llama_toolchain/memory/api/api.py @@ -68,6 +68,7 @@ MemoryBankConfig = Annotated[ class Chunk(BaseModel): content: InterleavedTextMedia token_count: int + document_id: str @json_schema_type diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py index 8f3ef2a4c..6a168d330 100644 --- a/llama_toolchain/memory/meta_reference/faiss/faiss.py +++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py @@ -89,6 +89,7 @@ class BankState: self.chunk_by_index[indexlen + i] = Chunk( content=chunk[0], token_count=chunk[1], + document_id=doc.document_id, ) print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}") self.id_by_index[indexlen + 
i] = doc.document_id