mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
agentic loop has a RAG implementation
This commit is contained in:
parent
77d6055d9f
commit
14637bea66
4 changed files with 245 additions and 111 deletions
|
@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session=session,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
|
@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def run(
|
||||
self,
|
||||
session: Session,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
|
@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
turn_id, input_messages, attachments, sampling_params, stream
|
||||
turn_id, session, input_messages, attachments, sampling_params, stream
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
async def _should_retrieve_context(
|
||||
self, messages: List[Message], attachments: List[Attachment]
|
||||
) -> bool:
|
||||
return self.agent_config.memory_configs or len(attachments) > 0
|
||||
|
||||
async def _retrieve_context(
|
||||
self, messages: List[Message], attachments: List[Attachment]
|
||||
) -> List[Message]:
|
||||
return []
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
session: Session,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
need_context = await self._should_retrieve_context(input_messages, attachments)
|
||||
if need_context:
|
||||
context_messages = await self._retrieve_context(input_messages)
|
||||
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
|
||||
# input_messages = input_messages + context
|
||||
input_messages = preprocess_dialog(input_messages)
|
||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||
need_rag_context = await self._should_retrieve_context(
|
||||
input_messages, attachments
|
||||
)
|
||||
if need_rag_context:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseStepStartPayload(
|
||||
step_type=StepType.memory_retrieval.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
attachments = []
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
rag_context, bank_ids = await self._retrieve_context(input_messages)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
event=AgenticSystemTurnResponseEvent(
|
||||
payload=AgenticSystemTurnResponseStepCompletePayload(
|
||||
step_type=StepType.memory_retrieval.value,
|
||||
step_id=step_id,
|
||||
step_details=MemoryRetrievalStep(
|
||||
memory_bank_ids=bank_ids,
|
||||
inserted_context=rag_context,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if rag_context:
|
||||
system_message = next(m for m in input_messages if m.role == "system")
|
||||
if system_message:
|
||||
system_message.content = system_message.content + "\n" + rag_context
|
||||
else:
|
||||
input_messages = [
|
||||
Message(role="system", content=rag_context)
|
||||
] + input_messages
|
||||
|
||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
input_messages.append(attachment_message(urls))
|
||||
|
||||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
while True:
|
||||
|
@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if len(message.tool_calls) == 0:
|
||||
if stop_reason == StopReason.end_of_turn:
|
||||
if len(attachments) > 0:
|
||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += attachments
|
||||
else:
|
||||
|
@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
attachments.append(result_message.content)
|
||||
output_attachments.append(result_message.content)
|
||||
elif isinstance(result_message.content, list) or isinstance(
|
||||
result_message.content, tuple
|
||||
):
|
||||
for c in result_message.content:
|
||||
if isinstance(c, Attachment):
|
||||
attachments.append(c)
|
||||
output_attachments.append(c)
|
||||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
n_iter += 1
|
||||
|
||||
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
|
||||
if session.memory_bank is None:
|
||||
session.memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session.session_id}",
|
||||
config=VectorMemoryBankConfig(
|
||||
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
|
||||
),
|
||||
)
|
||||
|
||||
def attachment_message(url: URL) -> ToolResponseMessage:
|
||||
uri = url.uri
|
||||
assert uri.startswith("file://")
|
||||
filepath = uri[len("file://") :]
|
||||
return session.memory_bank
|
||||
|
||||
async def _should_retrieve_context(
|
||||
self, messages: List[Message], attachments: List[Attachment]
|
||||
) -> bool:
|
||||
enabled_tools = set(t.type for t in self.agent_config.tools)
|
||||
if attachments:
|
||||
if (
|
||||
AgenticSystemTool.code_interpreter.value in enabled_tools
|
||||
and self.agent_config.tool_choice == ToolChoice.required
|
||||
):
|
||||
return False
|
||||
return attachments or AgenticSystemTool.memory.value in enabled_tools
|
||||
|
||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
||||
for t in self.agent_config.tools:
|
||||
if t.type == AgenticSystemTool.memory.value:
|
||||
return t
|
||||
|
||||
return None
|
||||
|
||||
async def _retrieve_context(
|
||||
self, session: Session, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Optional[InterleavedTextMedia]:
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
assert memory is not None, "Memory tool not configured"
|
||||
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
|
||||
|
||||
if attachments:
|
||||
bank = await self._ensure_memory_bank(session)
|
||||
bank_ids.append(bank.bank_id)
|
||||
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
doc_id=str(uuid.uuid4()),
|
||||
content=a.content,
|
||||
mime_type=a.mime_type,
|
||||
metadata={},
|
||||
)
|
||||
for a in attachments
|
||||
]
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
|
||||
assert len(bank_ids) > 0, "No memory banks configured?"
|
||||
|
||||
query = " ".join(m.content for m in messages)
|
||||
tasks = [
|
||||
self.memory_api.query_documents(
|
||||
bank_id=bank_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": 5,
|
||||
},
|
||||
)
|
||||
for bank_id in bank_ids
|
||||
]
|
||||
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(
|
||||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
if not chunks:
|
||||
return None
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks[: memory.max_chunks]:
|
||||
tokens += c.token_count
|
||||
if tokens > memory.max_tokens_in_context:
|
||||
cprint(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
"red",
|
||||
)
|
||||
break
|
||||
picked.append(c)
|
||||
|
||||
return [
|
||||
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
|
||||
|
||||
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
assert uri.startswith("file://")
|
||||
filepath = uri[len("file://") :]
|
||||
content.append(f'# There is a file accessible to you at "{filepath}"\n')
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=f'# There is a file accessible to you at "{filepath}"',
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_dialog(messages: List[Message]) -> List[Message]:
|
||||
# remove system message since those are
|
||||
"""
|
||||
Preprocesses the dialog by removing the system message and
|
||||
adding the system message to the beginning of the dialog.
|
||||
"""
|
||||
ret = []
|
||||
|
||||
for m in messages:
|
||||
if m.role == Role.system.value:
|
||||
continue
|
||||
|
||||
# NOTE: the ideal behavior is to use `file_path = ...` but that
|
||||
# means we need to have stateful execution of code which we currently
|
||||
# do not have.
|
||||
if isinstance(m.content, Attachment):
|
||||
ret.append(attachment_message(m.content.url))
|
||||
elif isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if isinstance(c, Attachment):
|
||||
ret.append(attachment_message(c.url))
|
||||
|
||||
ret.append(m)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
|
||||
) -> List[ToolResponseMessage]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue