Rag Updates

This commit is contained in:
Hardik Shah 2024-08-27 20:09:33 -07:00
parent ea6d9ec937
commit f5620c09ad
4 changed files with 48 additions and 40 deletions

View file

@ -8,7 +8,7 @@ import asyncio
import copy import copy
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List from typing import AsyncGenerator, List, Tuple
from termcolor import cprint from termcolor import cprint
@ -53,6 +53,44 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields, output_shields=agent_config.output_shields,
) )
def turn_to_messages(self, turn: Turn) -> List[Message]:
    """Flatten one turn into the ordered list of messages it produced.

    The result starts with copies of the turn's input messages, followed by
    the messages derived from each step, in step order:

    * inference step      -> the step's ``model_response``
    * tool execution step -> one ``ToolResponseMessage`` per tool response
    * shield call step    -> a ``CompletionMessage`` carrying the violation
      return message, emitted only when the shield reported a violation

    Steps of any other type contribute nothing.
    """
    history = []

    # Work on copies and clear ``context`` on user messages so that RAG
    # context attached to an earlier turn is not replayed on every later
    # request. Maybe this should become a configurable behaviour of the
    # agentic instance rather than being hard-coded here.
    for original in turn.input_messages:
        duplicate = original.copy()
        if isinstance(duplicate, UserMessage):
            duplicate.context = None
        history.append(duplicate)

    for step in turn.steps:
        kind = step.step_type

        if kind == StepType.inference.value:
            history.append(step.model_response)
            continue

        if kind == StepType.tool_execution.value:
            history.extend(
                ToolResponseMessage(
                    call_id=tool_response.call_id,
                    tool_name=tool_response.tool_name,
                    content=tool_response.content,
                )
                for tool_response in step.tool_responses
            )
            continue

        if kind != StepType.shield_call.value:
            continue

        shield_response = step.response
        if shield_response.is_violation:
            # The CompletionMessage is rebuilt here from the violation text
            # held by the shield response.
            history.append(
                CompletionMessage(
                    content=shield_response.violation_return_message,
                    stop_reason=StopReason.end_of_turn,
                )
            )

    return history
def create_session(self, name: str) -> Session: def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
session = Session( session = Session(
@ -75,32 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
messages = [] messages = []
for i, turn in enumerate(session.turns): for i, turn in enumerate(session.turns):
# print(f"turn {i}") messages.extend(self.turn_to_messages(turn))
# print_dialog(turn.input_messages)
messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
call_id=response.call_id,
tool_name=response.tool_name,
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
# TODO: Properly persist the
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
stop_reason=StopReason.end_of_turn,
)
)
messages.extend(request.messages) messages.extend(request.messages)
@ -325,15 +338,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
if rag_context: if rag_context:
system_message = next( last_message = input_messages[-1]
(m for m in input_messages if m.role == "system"), None last_message.context = "\n".join(rag_context)
)
if system_message:
system_message.content = system_message.content + "\n" + rag_context
else:
input_messages = [
SystemMessage(content=rag_context)
] + input_messages
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools: elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)] urls = [a.content for a in attachments if isinstance(a.content, URL)]
@ -610,7 +616,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context( async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment] self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Optional[InterleavedTextMedia]: ) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = [] bank_ids = []
memory = self._memory_tool_definition() memory = self._memory_tool_definition()
@ -671,10 +677,10 @@ class ChatAgent(ShieldRunnerMixin):
"red", "red",
) )
break break
picked.append(c.content) picked.append(f"id:{c.document_id}; content:{c.content}")
return [ return [
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n", "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked, *picked,
"\n=== END-RETRIEVED-CONTEXT ===\n", "\n=== END-RETRIEVED-CONTEXT ===\n",
], bank_ids ], bank_ids

View file

@ -49,7 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.inference: providers[Api.inference]["meta-ollama"], Api.inference: providers[Api.inference]["meta-ollama"],
Api.safety: providers[Api.safety]["meta-reference"], Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"], Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
Api.memory: remote_spec(Api.memory), Api.memory: providers[Api.memory]["meta-reference-faiss"],
}, },
), ),
DistributionSpec( DistributionSpec(

View file

@ -68,6 +68,7 @@ MemoryBankConfig = Annotated[
class Chunk(BaseModel): class Chunk(BaseModel):
content: InterleavedTextMedia content: InterleavedTextMedia
token_count: int token_count: int
document_id: str
@json_schema_type @json_schema_type

View file

@ -89,6 +89,7 @@ class BankState:
self.chunk_by_index[indexlen + i] = Chunk( self.chunk_by_index[indexlen + i] = Chunk(
content=chunk[0], content=chunk[0],
token_count=chunk[1], token_count=chunk[1],
document_id=doc.document_id,
) )
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}") print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
self.id_by_index[indexlen + i] = doc.document_id self.id_by_index[indexlen + i] = doc.document_id