agentic loop has a RAG implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 15:20:40 -07:00
parent 77d6055d9f
commit 14637bea66
4 changed files with 245 additions and 111 deletions

View file

@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403 from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403 from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403 from llama_toolchain.safety.api import * # noqa: F403
@ -21,7 +22,7 @@ from llama_toolchain.memory.api import * # noqa: F403
@json_schema_type @json_schema_type
class Attachment(BaseModel): class Attachment(BaseModel):
url: URL content: InterleavedTextMedia | URL
mime_type: str mime_type: str
@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon):
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@json_schema_type @json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon): class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
memory_banks: List[MemoryBank] = Field(default_factory=list) memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[ AgenticSystemToolDefinition = Annotated[
@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon):
StepType.memory_retrieval.value StepType.memory_retrieval.value
) )
memory_bank_ids: List[str] memory_bank_ids: List[str]
documents: List[MemoryBankDocument] inserted_context: InterleavedTextMedia
scores: List[float]
Step = Annotated[ Step = Annotated[
@ -185,38 +220,7 @@ class Session(BaseModel):
turns: List[Turn] turns: List[Turn]
started_at: datetime started_at: datetime
memory_bank: Optional[MemoryBank] = None
class MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel):
class AgentConfig(AgentConfigCommon): class AgentConfig(AgentConfigCommon):
model: str model: str
instructions: str instructions: str
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
class AgentConfigOverridablePerTurn(AgentConfigCommon): class AgentConfigOverridablePerTurn(AgentConfigCommon):

View file

@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
output_message = None output_message = None
async for chunk in self.run( async for chunk in self.run(
session=session,
turn_id=turn_id, turn_id=turn_id,
input_messages=messages, input_messages=messages,
attachments=request.attachments or [], attachments=request.attachments or [],
@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run( async def run(
self, self,
session: Session,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: List[Message],
attachments: List[Attachment], attachments: List[Attachment],
@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res yield res
async for res in self._run( async for res in self._run(
turn_id, input_messages, attachments, sampling_params, stream turn_id, session, input_messages, attachments, sampling_params, stream
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
return self.agent_config.memory_configs or len(attachments) > 0
async def _retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> List[Message]:
return []
async def _run( async def _run(
self, self,
session: Session,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: List[Message],
attachments: List[Attachment], attachments: List[Attachment],
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages, attachments) enabled_tools = set(t.type for t in self.agent_config.tools)
if need_context: need_rag_context = await self._should_retrieve_context(
context_messages = await self._retrieve_context(input_messages) input_messages, attachments
# input_messages = preprocess_dialog(input_messages, self.prefix_messages) )
# input_messages = input_messages + context if need_rag_context:
input_messages = preprocess_dialog(input_messages) step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
)
)
)
attachments = [] # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(input_messages)
step_id = str(uuid.uuid4())
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
memory_bank_ids=bank_ids,
inserted_context=rag_context,
),
)
)
)
if rag_context:
system_message = next(m for m in input_messages if m.role == "system")
if system_message:
system_message.content = system_message.content + "\n" + rag_context
else:
input_messages = [
Message(role="system", content=rag_context)
] + input_messages
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
input_messages.append(attachment_message(urls))
output_attachments = []
n_iter = 0 n_iter = 0
while True: while True:
@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin):
if len(message.tool_calls) == 0: if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn: if stop_reason == StopReason.end_of_turn:
if len(attachments) > 0: # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list): if isinstance(message.content, list):
message.content += attachments message.content += attachments
else: else:
@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin):
# NOTE: when we push this message back to the model, the model may ignore the # NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message # attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message # with the summary. We keep all generated attachments and then attach them to final message
attachments.append(result_message.content) output_attachments.append(result_message.content)
elif isinstance(result_message.content, list) or isinstance( elif isinstance(result_message.content, list) or isinstance(
result_message.content, tuple result_message.content, tuple
): ):
for c in result_message.content: for c in result_message.content:
if isinstance(c, Attachment): if isinstance(c, Attachment):
attachments.append(c) output_attachments.append(c)
input_messages = input_messages + [message, result_message] input_messages = input_messages + [message, result_message]
n_iter += 1 n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
),
)
def attachment_message(url: URL) -> ToolResponseMessage: return session.memory_bank
uri = url.uri
assert uri.startswith("file://") async def _should_retrieve_context(
filepath = uri[len("file://") :] self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgenticSystemTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
return attachments or AgenticSystemTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgenticSystemTool.memory.value:
return t
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
) -> Optional[InterleavedTextMedia]:
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
documents = [
MemoryBankDocument(
doc_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
await self.memory_api.insert_documents(bank_id, documents)
assert len(bank_ids) > 0, "No memory banks configured?"
query = " ".join(m.content for m in messages)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
if not chunks:
return None
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
cprint(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
"red",
)
break
picked.append(c)
return [
"The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage( return ToolResponseMessage(
call_id="", call_id="",
tool_name=BuiltinTool.code_interpreter, tool_name=BuiltinTool.code_interpreter,
content=f'# There is a file accessible to you at "{filepath}"', content=content,
) )
def preprocess_dialog(messages: List[Message]) -> List[Message]:
# remove system message since those are
"""
Preprocesses the dialog by removing the system message and
adding the system message to the beginning of the dialog.
"""
ret = []
for m in messages:
if m.role == Role.system.value:
continue
# NOTE: the ideal behavior is to use `file_path = ...` but that
# means we need to have stateful execution of code which we currently
# do not have.
if isinstance(m.content, Attachment):
ret.append(attachment_message(m.content.url))
elif isinstance(m.content, list):
for c in m.content:
if isinstance(c, Attachment):
ret.append(attachment_message(c.url))
ret.append(m)
return ret
async def execute_tool_call_maybe( async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage] tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
) -> List[ToolResponseMessage]: ) -> List[ToolResponseMessage]:

View file

@ -101,6 +101,11 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage] completion_message_batch: List[CompletionMessage]
@json_schema_type
class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class Inference(Protocol): class Inference(Protocol):
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
async def completion( async def completion(
@ -114,6 +119,13 @@ class Inference(Protocol):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/batch_completion") @webmethod(route="/inference/batch_completion")
async def batch_completion( async def batch_completion(
self, self,

View file

@ -23,17 +23,6 @@ class MemoryBankDocument(BaseModel):
metadata: Dict[str, Any] metadata: Dict[str, Any]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type @json_schema_type
class MemoryBankType(Enum): class MemoryBankType(Enum):
vector = "vector" vector = "vector"
@ -45,6 +34,7 @@ class MemoryBankType(Enum):
class VectorMemoryBankConfig(BaseModel): class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int
class KeyValueMemoryBankConfig(BaseModel): class KeyValueMemoryBankConfig(BaseModel):
@ -70,18 +60,39 @@ MemoryBankConfig = Annotated[
] ]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type @json_schema_type
class MemoryBank(BaseModel): class MemoryBank(BaseModel):
bank_id: str bank_id: str
name: str name: str
config: MemoryBankConfig config: MemoryBankConfig
# if there's a pre-existing store which obeys the MemoryBank REST interface # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None url: Optional[URL] = None
class Memory(Protocol): class Memory(Protocol):
@webmethod(route="/memory_banks/create") @webmethod(route="/memory_banks/create")
def create_memory_bank( async def create_memory_bank(
self, self,
name: str, name: str,
config: MemoryBankConfig, config: MemoryBankConfig,
@ -89,33 +100,35 @@ class Memory(Protocol):
) -> MemoryBank: ... ) -> MemoryBank: ...
@webmethod(route="/memory_banks/list", method="GET") @webmethod(route="/memory_banks/list", method="GET")
def list_memory_banks(self) -> List[MemoryBank]: ... async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get") @webmethod(route="/memory_banks/get")
def get_memory_bank(self, bank_id: str) -> MemoryBank: ... async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
@webmethod(route="/memory_banks/drop", method="DELETE") @webmethod(route="/memory_banks/drop", method="DELETE")
def drop_memory_bank( async def drop_memory_bank(
self, self,
bank_id: str, bank_id: str,
) -> str: ... ) -> str: ...
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory_bank/insert") @webmethod(route="/memory_bank/insert")
def insert_documents( async def insert_documents(
self, self,
bank_id: str, bank_id: str,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ... ) -> None: ...
@webmethod(route="/memory_bank/update") @webmethod(route="/memory_bank/update")
def update_documents( async def update_documents(
self, self,
bank_id: str, bank_id: str,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ... ) -> None: ...
@webmethod(route="/memory_bank/query") @webmethod(route="/memory_bank/query")
def query_documents( async def query_documents(
self, self,
bank_id: str, bank_id: str,
query: InterleavedTextMedia, query: InterleavedTextMedia,
@ -123,14 +136,14 @@ class Memory(Protocol):
) -> QueryDocumentsResponse: ... ) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get") @webmethod(route="/memory_bank/documents/get")
def get_documents( async def get_documents(
self, self,
bank_id: str, bank_id: str,
document_ids: List[str], document_ids: List[str],
) -> List[MemoryBankDocument]: ... ) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete") @webmethod(route="/memory_bank/documents/delete")
def delete_documents( async def delete_documents(
self, self,
bank_id: str, bank_id: str,
document_ids: List[str], document_ids: List[str],