mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
basic RAG seems to work
This commit is contained in:
parent
830252257b
commit
58e2feceb0
3 changed files with 96 additions and 44 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
@ -304,7 +304,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
rag_context, bank_ids = await self._retrieve_context(input_messages)
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session, input_messages, attachments
|
||||
)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgenticSystemTurnResponseStreamChunk(
|
||||
|
@ -313,20 +315,24 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_type=StepType.memory_retrieval.value,
|
||||
step_id=step_id,
|
||||
step_details=MemoryRetrievalStep(
|
||||
turn_id=turn_id,
|
||||
step_id=step_id,
|
||||
memory_bank_ids=bank_ids,
|
||||
inserted_context=rag_context,
|
||||
inserted_context=rag_context or "",
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if rag_context:
|
||||
system_message = next(m for m in input_messages if m.role == "system")
|
||||
system_message = next(
|
||||
(m for m in input_messages if m.role == "system"), None
|
||||
)
|
||||
if system_message:
|
||||
system_message.content = system_message.content + "\n" + rag_context
|
||||
else:
|
||||
input_messages = [
|
||||
Message(role="system", content=rag_context)
|
||||
SystemMessage(content=rag_context)
|
||||
] + input_messages
|
||||
|
||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||
|
@ -644,7 +650,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||
)
|
||||
if not chunks:
|
||||
return None
|
||||
return None, bank_ids
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
|
@ -656,13 +662,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"red",
|
||||
)
|
||||
break
|
||||
picked.append(c)
|
||||
picked.append(c.content)
|
||||
|
||||
return [
|
||||
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
], bank_ids
|
||||
|
||||
def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue