agentic loop has a RAG implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 15:20:40 -07:00
parent 77d6055d9f
commit 14637bea66
4 changed files with 245 additions and 111 deletions

View file

@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
@ -21,7 +22,7 @@ from llama_toolchain.memory.api import * # noqa: F403
@json_schema_type
class Attachment(BaseModel):
url: URL
content: InterleavedTextMedia | URL
mime_type: str
@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_banks: List[MemoryBank] = Field(default_factory=list)
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon):
StepType.memory_retrieval.value
)
memory_bank_ids: List[str]
documents: List[MemoryBankDocument]
scores: List[float]
inserted_context: InterleavedTextMedia
Step = Annotated[
@ -185,38 +220,7 @@ class Session(BaseModel):
turns: List[Turn]
started_at: datetime
class MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
memory_bank: Optional[MemoryBank] = None
class AgentConfigCommon(BaseModel):
@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel):
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
class AgentConfigOverridablePerTurn(AgentConfigCommon):

View file

@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
output_message = None
async for chunk in self.run(
session=session,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
turn_id, input_messages, attachments, sampling_params, stream
turn_id, session, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin):
)
)
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
return self.agent_config.memory_configs or len(attachments) > 0
async def _retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> List[Message]:
return []
async def _run(
self,
session: Session,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages, attachments)
if need_context:
context_messages = await self._retrieve_context(input_messages)
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
# input_messages = input_messages + context
input_messages = preprocess_dialog(input_messages)
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
)
)
attachments = []
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(input_messages)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
memory_bank_ids=bank_ids,
inserted_context=rag_context,
),
)
)
)
if rag_context:
system_message = next(m for m in input_messages if m.role == "system")
if system_message:
system_message.content = system_message.content + "\n" + rag_context
else:
input_messages = [
Message(role="system", content=rag_context)
] + input_messages
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
input_messages.append(attachment_message(urls))
output_attachments = []
n_iter = 0
while True:
@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin):
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
if len(attachments) > 0:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
else:
@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
attachments.append(result_message.content)
output_attachments.append(result_message.content)
elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple
):
for c in result_message.content:
if isinstance(c, Attachment):
attachments.append(c)
output_attachments.append(c)
input_messages = input_messages + [message, result_message]
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
),
)
def attachment_message(url: URL) -> ToolResponseMessage:
return session.memory_bank
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgenticSystemTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
return attachments or AgenticSystemTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgenticSystemTool.memory.value:
return t
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Optional[InterleavedTextMedia]:
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
documents = [
MemoryBankDocument(
doc_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
await self.memory_api.insert_documents(bank_id, documents)
assert len(bank_ids) > 0, "No memory banks configured?"
query = " ".join(m.content for m in messages)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(c)
return [
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage(
call_id="",
tool_name=BuiltinTool.code_interpreter,
content=f'# There is a file accessible to you at "{filepath}"',
content=content,
)
def preprocess_dialog(messages: List[Message]) -> List[Message]:
# remove system message since those are
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
"""
ret = []
for m in messages:
if m.role == Role.system.value:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution of code which we currently
# do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))
elif isinstance(m.content, list):
for c in m.content:
if isinstance(c, Attachment):
ret.append(attachment_message(c.url))
ret.append(m)
return ret
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]:

View file

@ -101,6 +101,11 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
@ -114,6 +119,13 @@ class Inference(Protocol):
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,

View file

@ -23,17 +23,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
@ -45,6 +34,7 @@ class MemoryBankType(Enum):
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
class KeyValueMemoryBankConfig(BaseModel):
@ -70,18 +60,39 @@ MemoryBankConfig = Annotated[
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing store which obeys the MemoryBank REST interface
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
def create_memory_bank(
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
@ -89,33 +100,35 @@ class Memory(Protocol):
) -> MemoryBank: ...
@webmethod(route="/memory_banks/list", method="GET")
def list_memory_banks(self) -> List[MemoryBank]: ...
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
def drop_memory_bank(
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory_bank/insert")
def insert_documents(
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/update")
def update_documents(
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/query")
def query_documents(
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
@ -123,14 +136,14 @@ class Memory(Protocol):
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get")
def get_documents(
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete")
def delete_documents(
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],