agentic loop has a RAG implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 15:20:40 -07:00
parent 77d6055d9f
commit 14637bea66
4 changed files with 245 additions and 111 deletions

View file

@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
output_message = None
async for chunk in self.run(
session=session,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
turn_id, input_messages, attachments, sampling_params, stream
turn_id, session, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin):
)
)
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
return self.agent_config.memory_configs or len(attachments) > 0
async def _retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> List[Message]:
return []
async def _run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages, attachments)
if need_context:
context_messages = await self._retrieve_context(input_messages)
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
# input_messages = input_messages + context
input_messages = preprocess_dialog(input_messages)
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
)
)
attachments = []
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(input_messages)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
memory_bank_ids=bank_ids,
inserted_context=rag_context,
),
)
)
)
if rag_context:
system_message = next(m for m in input_messages if m.role == "system")
if system_message:
system_message.content = system_message.content + "\n" + rag_context
else:
input_messages = [
Message(role="system", content=rag_context)
] + input_messages
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
input_messages.append(attachment_message(urls))
output_attachments = []
n_iter = 0
while True:
@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin):
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
if len(attachments) > 0:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
else:
@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
attachments.append(result_message.content)
output_attachments.append(result_message.content)
elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple
):
for c in result_message.content:
if isinstance(c, Attachment):
attachments.append(c)
output_attachments.append(c)
input_messages = input_messages + [message, result_message]
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
),
)
def attachment_message(url: URL) -> ToolResponseMessage:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
return session.memory_bank
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgenticSystemTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
return attachments or AgenticSystemTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgenticSystemTool.memory.value:
return t
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Optional[InterleavedTextMedia]:
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
documents = [
MemoryBankDocument(
doc_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
await self.memory_api.insert_documents(bank_id, documents)
assert len(bank_ids) > 0, "No memory banks configured?"
query = " ".join(m.content for m in messages)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(c)
return [
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage(
call_id="",
tool_name=BuiltinTool.code_interpreter,
content=f'# There is a file accessible to you at "{filepath}"',
content=content,
)
def preprocess_dialog(messages: List[Message]) -> List[Message]:
# remove system message since those are
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
"""
ret = []
for m in messages:
if m.role == Role.system.value:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution of code which we currently
# do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))
elif isinstance(m.content, list):
for c in m.content:
if isinstance(c, Attachment):
ret.append(attachment_message(c.url))
ret.append(m)
return ret
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]: