mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Rag Updates
This commit is contained in:
parent
ea6d9ec937
commit
f5620c09ad
4 changed files with 48 additions and 40 deletions
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import copy
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -53,6 +53,44 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
# We do not want to keep adding RAG context to the input messages
|
||||
# May be this should be a parameter of the agentic instance
|
||||
# that can define its behavior in a custom way
|
||||
for m in turn.input_messages:
|
||||
msg = m.copy()
|
||||
if isinstance(msg, UserMessage):
|
||||
msg.context = None
|
||||
messages.append(msg)
|
||||
|
||||
# messages.extend(turn.input_messages)
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
response = step.response
|
||||
if response.is_violation:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=response.violation_return_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
# print_dialog(messages)
|
||||
return messages
|
||||
|
||||
def create_session(self, name: str) -> Session:
|
||||
session_id = str(uuid.uuid4())
|
||||
session = Session(
|
||||
|
@ -75,32 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
messages = []
|
||||
for i, turn in enumerate(session.turns):
|
||||
# print(f"turn {i}")
|
||||
# print_dialog(turn.input_messages)
|
||||
messages.extend(turn.input_messages)
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
response = step.response
|
||||
if response.is_violation:
|
||||
# TODO: Properly persist the
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=response.violation_return_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
|
||||
|
@ -325,15 +338,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if rag_context:
|
||||
system_message = next(
|
||||
(m for m in input_messages if m.role == "system"), None
|
||||
)
|
||||
if system_message:
|
||||
system_message.content = system_message.content + "\n" + rag_context
|
||||
else:
|
||||
input_messages = [
|
||||
SystemMessage(content=rag_context)
|
||||
] + input_messages
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = "\n".join(rag_context)
|
||||
|
||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
|
@ -610,7 +616,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _retrieve_context(
|
||||
self, session: Session, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Optional[InterleavedTextMedia]:
|
||||
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
|
@ -671,10 +677,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"red",
|
||||
)
|
||||
break
|
||||
picked.append(c.content)
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
|
||||
return [
|
||||
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
], bank_ids
|
||||
|
|
|
@ -49,7 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.inference: providers[Api.inference]["meta-ollama"],
|
||||
Api.safety: providers[Api.safety]["meta-reference"],
|
||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||
Api.memory: remote_spec(Api.memory),
|
||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
|
|
@ -68,6 +68,7 @@ MemoryBankConfig = Annotated[
|
|||
class Chunk(BaseModel):
|
||||
content: InterleavedTextMedia
|
||||
token_count: int
|
||||
document_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -89,6 +89,7 @@ class BankState:
|
|||
self.chunk_by_index[indexlen + i] = Chunk(
|
||||
content=chunk[0],
|
||||
token_count=chunk[1],
|
||||
document_id=doc.document_id,
|
||||
)
|
||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
||||
self.id_by_index[indexlen + i] = doc.document_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue