RAG Updates

Hardik Shah 2024-08-27 20:09:33 -07:00
parent ea6d9ec937
commit f5620c09ad
4 changed files with 48 additions and 40 deletions

View file

@@ -8,7 +8,7 @@ import asyncio
 import copy
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Tuple

 from termcolor import cprint
@@ -53,6 +53,44 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )

+    def turn_to_messages(self, turn: Turn) -> List[Message]:
+        messages = []
+        # We do not want to keep adding RAG context to the input messages
+        # May be this should be a parameter of the agentic instance
+        # that can define its behavior in a custom way
+        for m in turn.input_messages:
+            msg = m.copy()
+            if isinstance(msg, UserMessage):
+                msg.context = None
+            messages.append(msg)
+        # messages.extend(turn.input_messages)
+
+        for step in turn.steps:
+            if step.step_type == StepType.inference.value:
+                messages.append(step.model_response)
+            elif step.step_type == StepType.tool_execution.value:
+                for response in step.tool_responses:
+                    messages.append(
+                        ToolResponseMessage(
+                            call_id=response.call_id,
+                            tool_name=response.tool_name,
+                            content=response.content,
+                        )
+                    )
+            elif step.step_type == StepType.shield_call.value:
+                response = step.response
+                if response.is_violation:
+                    # TODO: Properly persist the
+                    # CompletionMessage itself in the ShieldResponse
+                    messages.append(
+                        CompletionMessage(
+                            content=response.violation_return_message,
+                            stop_reason=StopReason.end_of_turn,
+                        )
+                    )
+        # print_dialog(messages)
+        return messages
+
     def create_session(self, name: str) -> Session:
         session_id = str(uuid.uuid4())
         session = Session(
@@ -75,32 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
         messages = []
         for i, turn in enumerate(session.turns):
-            # print(f"turn {i}")
-            # print_dialog(turn.input_messages)
-            messages.extend(turn.input_messages)
-            for step in turn.steps:
-                if step.step_type == StepType.inference.value:
-                    messages.append(step.model_response)
-                elif step.step_type == StepType.tool_execution.value:
-                    for response in step.tool_responses:
-                        messages.append(
-                            ToolResponseMessage(
-                                call_id=response.call_id,
-                                tool_name=response.tool_name,
-                                content=response.content,
-                            )
-                        )
-                elif step.step_type == StepType.shield_call.value:
-                    response = step.response
-                    if response.is_violation:
-                        # TODO: Properly persist the
-                        # CompletionMessage itself in the ShieldResponse
-                        messages.append(
-                            CompletionMessage(
-                                content=response.violation_return_message,
-                                stop_reason=StopReason.end_of_turn,
-                            )
-                        )
+            messages.extend(self.turn_to_messages(turn))

         messages.extend(request.messages)
@@ -325,15 +338,8 @@ class ChatAgent(ShieldRunnerMixin):
         )

         if rag_context:
-            system_message = next(
-                (m for m in input_messages if m.role == "system"), None
-            )
-            if system_message:
-                system_message.content = system_message.content + "\n" + rag_context
-            else:
-                input_messages = [
-                    SystemMessage(content=rag_context)
-                ] + input_messages
+            last_message = input_messages[-1]
+            last_message.context = "\n".join(rag_context)

         elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
@@ -610,7 +616,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def _retrieve_context(
         self, session: Session, messages: List[Message], attachments: List[Attachment]
-    ) -> Optional[InterleavedTextMedia]:
+    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
         bank_ids = []

         memory = self._memory_tool_definition()
@@ -671,10 +677,10 @@ class ChatAgent(ShieldRunnerMixin):
                     "red",
                 )
                 break
-            picked.append(c.content)
+            picked.append(f"id:{c.document_id}; content:{c.content}")

         return [
-            "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
+            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
             *picked,
             "\n=== END-RETRIEVED-CONTEXT ===\n",
         ], bank_ids
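
Taken together, the changes in this file move retrieved RAG context out of the system message and onto the last input message, and strip that context again when past turns are replayed through turn_to_messages. Below is a minimal, self-contained sketch of that flow; the simplified UserMessage dataclass and the helper names attach_rag_context / strip_rag_context are illustrative stand-ins, not the repo's actual pydantic models or methods.

# Illustrative sketch only: simplified stand-ins for the repo's message types.
import copy
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class UserMessage:
    content: str
    context: Optional[str] = None


def attach_rag_context(input_messages: List[UserMessage], rag_context: List[str]) -> None:
    # New behavior: join the retrieved chunks and attach them to the *last*
    # input message instead of appending them to a system message.
    if rag_context:
        input_messages[-1].context = "\n".join(rag_context)


def strip_rag_context(input_messages: List[UserMessage]) -> List[UserMessage]:
    # Mirrors turn_to_messages: when a past turn is replayed, its RAG context
    # is cleared so retrieved text is not fed to the model twice.
    replayed = []
    for m in input_messages:
        msg = copy.copy(m)
        msg.context = None
        replayed.append(msg)
    return replayed


if __name__ == "__main__":
    picked = [
        "id:doc-1; content:Chunks now carry the id of their source document.",
        "id:doc-2; content:Retrieved context is attached to the last user message.",
    ]
    rag_context = [
        "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
        *picked,
        "\n=== END-RETRIEVED-CONTEXT ===\n",
    ]
    turn = [UserMessage(content="What changed in this commit?")]
    attach_rag_context(turn, rag_context)
    print(turn[-1].context is not None)         # True: context attached for the current turn
    print(strip_rag_context(turn)[-1].context)  # None: dropped when the turn is replayed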

View file

@@ -49,7 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
                Api.inference: providers[Api.inference]["meta-ollama"],
                Api.safety: providers[Api.safety]["meta-reference"],
                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
-               Api.memory: remote_spec(Api.memory),
+               Api.memory: providers[Api.memory]["meta-reference-faiss"],
            },
        ),
        DistributionSpec(
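
The one-line change above swaps the memory API in this distribution from a remote spec to the in-process reference FAISS provider. The toy sketch below only illustrates that kind of registry lookup; the Api enum values, remote_spec return value, and provider strings here are made-up stand-ins, not the repo's real DistributionSpec machinery.

# Toy stand-ins only: the real provider registry carries full provider specs,
# not plain strings.
from enum import Enum
from typing import Dict


class Api(Enum):
    memory = "memory"


def remote_spec(api: Api) -> str:
    # Old wiring: the memory API was served by a remote endpoint.
    return f"remote provider for {api.value}"


providers: Dict[Api, Dict[str, str]] = {
    # New wiring: the local distribution runs the reference FAISS-backed
    # memory provider in-process.
    Api.memory: {"meta-reference-faiss": "in-process FAISS-backed memory provider"},
}

print(remote_spec(Api.memory))
print(providers[Api.memory]["meta-reference-faiss"])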

View file

@@ -68,6 +68,7 @@ MemoryBankConfig = Annotated[
 class Chunk(BaseModel):
     content: InterleavedTextMedia
     token_count: int
+    document_id: str


 @json_schema_type
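
With the field added above, Chunk (a pydantic model in this API) records which document each chunk was cut from, which is what lets the retrieval path build the `id:...; content:...` strings seen earlier in this commit. A minimal stand-in sketch, using a plain str for content instead of the repo's InterleavedTextMedia, with an example document_id value:

# Simplified stand-in for the updated Chunk model.
from pydantic import BaseModel


class Chunk(BaseModel):
    content: str
    token_count: int
    document_id: str


chunk = Chunk(
    content="FAISS stores dense vectors for similarity search.",
    token_count=8,
    document_id="memory_bank_docs/0001",
)
# The retrieval code can now cite the source document alongside the text:
print(f"id:{chunk.document_id}; content:{chunk.content}")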

View file

@@ -89,6 +89,7 @@ class BankState:
                 self.chunk_by_index[indexlen + i] = Chunk(
                     content=chunk[0],
                     token_count=chunk[1],
+                    document_id=doc.document_id,
                 )
                 print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
                 self.id_by_index[indexlen + i] = doc.document_id
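
A sketch of the bookkeeping change in this last hunk: when documents are added to a bank, every stored chunk now remembers the id of the document it came from, alongside the existing id_by_index mapping. The dict names mirror the BankState fields in the diff; tokenization and the FAISS embedding/indexing steps are omitted, and the Chunk class is the same simplified stand-in as in the previous sketch.

# Illustrative sketch; tokenization and FAISS indexing are omitted.
from typing import Dict, List, Tuple

from pydantic import BaseModel


class Chunk(BaseModel):  # simplified stand-in, as in the previous sketch
    content: str
    token_count: int
    document_id: str


chunk_by_index: Dict[int, Chunk] = {}
id_by_index: Dict[int, str] = {}


def add_document_chunks(document_id: str, chunks: List[Tuple[str, int]]) -> None:
    # `chunks` holds (text, token_count) pairs, mirroring chunk[0] / chunk[1] above.
    indexlen = len(id_by_index)
    for i, (text, token_count) in enumerate(chunks):
        # Every stored chunk now records which document it was cut from.
        chunk_by_index[indexlen + i] = Chunk(
            content=text, token_count=token_count, document_id=document_id
        )
        id_by_index[indexlen + i] = document_id


add_document_chunks("getting_started.md", [("Install the CLI.", 4), ("Start the server.", 4)])
print(chunk_by_index[1].document_id)  # -> getting_started.md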