From 14637bea66012ed243769087d785fb6d25f2d147 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 23 Aug 2024 15:20:40 -0700 Subject: [PATCH] agentic loop has a RAG implementation --- .../agentic_system/api/datatypes.py | 77 ++++--- .../meta_reference/agent_instance.py | 212 +++++++++++++----- llama_toolchain/inference/api/endpoints.py | 12 + llama_toolchain/memory/api/endpoints.py | 55 +++-- 4 files changed, 245 insertions(+), 111 deletions(-) diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index dbebb9fec..cb99d80fc 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.common.deployment_types import * # noqa: F403 from llama_toolchain.inference.api import * # noqa: F403 from llama_toolchain.safety.api import * # noqa: F403 @@ -21,7 +22,7 @@ from llama_toolchain.memory.api import * # noqa: F403 @json_schema_type class Attachment(BaseModel): - url: URL + content: InterleavedTextMedia | URL mime_type: str @@ -81,10 +82,45 @@ class FunctionCallToolDefinition(ToolDefinitionCommon): remote_execution: Optional[RestAPIExecutionConfig] = None +class _MemoryBankConfigCommon(BaseModel): + bank_id: str + + +class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + + +class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value + keys: List[str] # what keys to focus on + + +class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value + + +class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + entities: List[str] # what entities to focus on + + +MemoryBankConfig = Annotated[ + Union[ + AgenticSystemVectorMemoryBankConfig, + AgenticSystemKeyValueMemoryBankConfig, + AgenticSystemKeywordMemoryBankConfig, + AgenticSystemGraphMemoryBankConfig, + ], + Field(discriminator="type"), +] + + @json_schema_type class MemoryToolDefinition(ToolDefinitionCommon): type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value - memory_banks: List[MemoryBank] = Field(default_factory=list) + memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) + max_tokens_in_context: int = 4096 + max_chunks: int = 10 AgenticSystemToolDefinition = Annotated[ @@ -141,8 +177,7 @@ class MemoryRetrievalStep(StepCommon): StepType.memory_retrieval.value ) memory_bank_ids: List[str] - documents: List[MemoryBankDocument] - scores: List[float] + inserted_context: InterleavedTextMedia Step = Annotated[ @@ -185,38 +220,7 @@ class Session(BaseModel): turns: List[Turn] started_at: datetime - -class MemoryBankConfigCommon(BaseModel): - bank_id: str - - -class VectorMemoryBankConfig(MemoryBankConfigCommon): - type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value - - -class KeyValueMemoryBankConfig(MemoryBankConfigCommon): - type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value - keys: List[str] # what keys to focus on - - -class KeywordMemoryBankConfig(MemoryBankConfigCommon): - type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value - - -class GraphMemoryBankConfig(MemoryBankConfigCommon): - type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value - entities: List[str] # what entities to focus on - - -MemoryBankConfig = Annotated[ - Union[ - VectorMemoryBankConfig, - KeyValueMemoryBankConfig, - KeywordMemoryBankConfig, - GraphMemoryBankConfig, - ], - Field(discriminator="type"), -] + memory_bank: Optional[MemoryBank] = None class AgentConfigCommon(BaseModel): @@ -236,7 +240,6 @@ class AgentConfigCommon(BaseModel): class AgentConfig(AgentConfigCommon): model: str instructions: str - memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list) class AgentConfigOverridablePerTurn(AgentConfigCommon): diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 947b94a0f..0cb5c3a0e 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -119,6 +119,7 @@ class ChatAgent(ShieldRunnerMixin): steps = [] output_message = None async for chunk in self.run( + session=session, turn_id=turn_id, input_messages=messages, attachments=request.attachments or [], @@ -170,6 +171,7 @@ class ChatAgent(ShieldRunnerMixin): async def run( self, + session: Session, turn_id: str, input_messages: List[Message], attachments: List[Attachment], @@ -190,7 +192,7 @@ class ChatAgent(ShieldRunnerMixin): yield res async for res in self._run( - turn_id, input_messages, attachments, sampling_params, stream + turn_id, session, input_messages, attachments, sampling_params, stream ): if isinstance(res, bool): return @@ -275,32 +277,62 @@ class ChatAgent(ShieldRunnerMixin): ) ) - async def _should_retrieve_context( - self, messages: List[Message], attachments: List[Attachment] - ) -> bool: - return self.agent_config.memory_configs or len(attachments) > 0 - - async def _retrieve_context( - self, messages: List[Message], attachments: List[Attachment] - ) -> List[Message]: - return [] - async def _run( self, + session: Session, turn_id: str, input_messages: List[Message], attachments: List[Attachment], sampling_params: SamplingParams, stream: bool = False, ) -> AsyncGenerator: - need_context = await self._should_retrieve_context(input_messages, attachments) - if need_context: - context_messages = await self._retrieve_context(input_messages) - # input_messages = preprocess_dialog(input_messages, self.prefix_messages) - # input_messages = input_messages + context - input_messages = preprocess_dialog(input_messages) + enabled_tools = set(t.type for t in self.agent_config.tools) + need_rag_context = await self._should_retrieve_context( + input_messages, attachments + ) + if need_rag_context: + step_id = str(uuid.uuid4()) + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepStartPayload( + step_type=StepType.memory_retrieval.value, + step_id=step_id, + ) + ) + ) - attachments = [] + # TODO: find older context from the session and either replace it + # or append with a sliding window. this is really a very simplistic implementation + rag_context, bank_ids = await self._retrieve_context(input_messages) + + step_id = str(uuid.uuid4()) + yield AgenticSystemTurnResponseStreamChunk( + event=AgenticSystemTurnResponseEvent( + payload=AgenticSystemTurnResponseStepCompletePayload( + step_type=StepType.memory_retrieval.value, + step_id=step_id, + step_details=MemoryRetrievalStep( + memory_bank_ids=bank_ids, + inserted_context=rag_context, + ), + ) + ) + ) + + if rag_context: + system_message = next(m for m in input_messages if m.role == "system") + if system_message: + system_message.content = system_message.content + "\n" + rag_context + else: + input_messages = [ + Message(role="system", content=rag_context) + ] + input_messages + + elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools: + urls = [a.content for a in attachments if isinstance(a.content, URL)] + input_messages.append(attachment_message(urls)) + + output_attachments = [] n_iter = 0 while True: @@ -414,7 +446,8 @@ class ChatAgent(ShieldRunnerMixin): if len(message.tool_calls) == 0: if stop_reason == StopReason.end_of_turn: - if len(attachments) > 0: + # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) + if len(output_attachments) > 0: if isinstance(message.content, list): message.content += attachments else: @@ -526,58 +559,131 @@ class ChatAgent(ShieldRunnerMixin): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message # with the summary. We keep all generated attachments and then attach them to final message - attachments.append(result_message.content) + output_attachments.append(result_message.content) elif isinstance(result_message.content, list) or isinstance( result_message.content, tuple ): for c in result_message.content: if isinstance(c, Attachment): - attachments.append(c) + output_attachments.append(c) input_messages = input_messages + [message, result_message] n_iter += 1 + async def _ensure_memory_bank(self, session: Session) -> MemoryBank: + if session.memory_bank is None: + session.memory_bank = await self.memory_api.create_memory_bank( + name=f"memory_bank_{session.session_id}", + config=VectorMemoryBankConfig( + embedding_model="sentence-transformer/all-MiniLM-L6-v2", + ), + ) -def attachment_message(url: URL) -> ToolResponseMessage: - uri = url.uri - assert uri.startswith("file://") - filepath = uri[len("file://") :] + return session.memory_bank + + async def _should_retrieve_context( + self, messages: List[Message], attachments: List[Attachment] + ) -> bool: + enabled_tools = set(t.type for t in self.agent_config.tools) + if attachments: + if ( + AgenticSystemTool.code_interpreter.value in enabled_tools + and self.agent_config.tool_choice == ToolChoice.required + ): + return False + return attachments or AgenticSystemTool.memory.value in enabled_tools + + def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: + for t in self.agent_config.tools: + if t.type == AgenticSystemTool.memory.value: + return t + + return None + + async def _retrieve_context( + self, session: Session, messages: List[Message], attachments: List[Attachment] + ) -> Optional[InterleavedTextMedia]: + bank_ids = [] + + memory = self._memory_tool_definition() + assert memory is not None, "Memory tool not configured" + bank_ids.extend(c.bank_id for c in memory.memory_bank_configs) + + if attachments: + bank = await self._ensure_memory_bank(session) + bank_ids.append(bank.bank_id) + + documents = [ + MemoryBankDocument( + doc_id=str(uuid.uuid4()), + content=a.content, + mime_type=a.mime_type, + metadata={}, + ) + for a in attachments + ] + await self.memory_api.insert_documents(bank_id, documents) + + assert len(bank_ids) > 0, "No memory banks configured?" + + query = " ".join(m.content for m in messages) + tasks = [ + self.memory_api.query_documents( + bank_id=bank_id, + query=query, + params={ + "max_chunks": 5, + }, + ) + for bank_id in bank_ids + ] + results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) + chunks = [c for r in results for c in r.chunks] + scores = [s for r in results for s in r.scores] + + # sort by score + chunks, scores = zip( + *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) + ) + if not chunks: + return None + + tokens = 0 + picked = [] + for c in chunks[: memory.max_chunks]: + tokens += c.token_count + if tokens > memory.max_tokens_in_context: + cprint( + f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", + "red", + ) + break + picked.append(c) + + return [ + "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n", + *picked, + "\n=== END-RETRIEVED-CONTEXT ===\n", + ] + + +def attachment_message(urls: List[URL]) -> ToolResponseMessage: + content = [] + + for url in urls: + uri = url.uri + assert uri.startswith("file://") + filepath = uri[len("file://") :] + content.append(f'# There is a file accessible to you at "{filepath}"\n') return ToolResponseMessage( call_id="", tool_name=BuiltinTool.code_interpreter, - content=f'# There is a file accessible to you at "{filepath}"', + content=content, ) -def preprocess_dialog(messages: List[Message]) -> List[Message]: - # remove system message since those are - """ - Preprocesses the dialog by removing the system message and - adding the system message to the beginning of the dialog. - """ - ret = [] - - for m in messages: - if m.role == Role.system.value: - continue - - # NOTE: the ideal behavior is to use `file_path = ...` but that - # means we need to have stateful execution of code which we currently - # do not have. - if isinstance(m.content, Attachment): - ret.append(attachment_message(m.content.url)) - elif isinstance(m.content, list): - for c in m.content: - if isinstance(c, Attachment): - ret.append(attachment_message(c.url)) - - ret.append(m) - - return ret - - async def execute_tool_call_maybe( tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage] ) -> List[ToolResponseMessage]: diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py index a4c4d4095..f09c0e3f8 100644 --- a/llama_toolchain/inference/api/endpoints.py +++ b/llama_toolchain/inference/api/endpoints.py @@ -101,6 +101,11 @@ class BatchChatCompletionResponse(BaseModel): completion_message_batch: List[CompletionMessage] +@json_schema_type +class EmbeddingsResponse(BaseModel): + embeddings: List[List[float]] + + class Inference(Protocol): @webmethod(route="/inference/completion") async def completion( @@ -114,6 +119,13 @@ class Inference(Protocol): request: ChatCompletionRequest, ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... + @webmethod(route="/inference/embeddings") + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: ... + @webmethod(route="/inference/batch_completion") async def batch_completion( self, diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py index 29c2c889e..615014b55 100644 --- a/llama_toolchain/memory/api/endpoints.py +++ b/llama_toolchain/memory/api/endpoints.py @@ -23,17 +23,6 @@ class MemoryBankDocument(BaseModel): metadata: Dict[str, Any] -class Chunk(BaseModel): - content: InterleavedTextMedia - token_count: int - - -@json_schema_type -class QueryDocumentsResponse(BaseModel): - chunks: List[Chunk] - scores: List[float] - - @json_schema_type class MemoryBankType(Enum): vector = "vector" @@ -45,6 +34,7 @@ class MemoryBankType(Enum): class VectorMemoryBankConfig(BaseModel): type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str + chunk_size_in_tokens: int class KeyValueMemoryBankConfig(BaseModel): @@ -70,18 +60,39 @@ MemoryBankConfig = Annotated[ ] +class Chunk(BaseModel): + content: InterleavedTextMedia + token_count: int + + +@json_schema_type +class QueryDocumentsResponse(BaseModel): + chunks: List[Chunk] + scores: List[float] + + +@json_schema_type +class QueryAPI(Protocol): + @webmethod(route="/query_documents") + def query_documents( + self, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: ... + + @json_schema_type class MemoryBank(BaseModel): bank_id: str name: str config: MemoryBankConfig - # if there's a pre-existing store which obeys the MemoryBank REST interface + # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI url: Optional[URL] = None class Memory(Protocol): @webmethod(route="/memory_banks/create") - def create_memory_bank( + async def create_memory_bank( self, name: str, config: MemoryBankConfig, @@ -89,33 +100,35 @@ class Memory(Protocol): ) -> MemoryBank: ... @webmethod(route="/memory_banks/list", method="GET") - def list_memory_banks(self) -> List[MemoryBank]: ... + async def list_memory_banks(self) -> List[MemoryBank]: ... @webmethod(route="/memory_banks/get") - def get_memory_bank(self, bank_id: str) -> MemoryBank: ... + async def get_memory_bank(self, bank_id: str) -> MemoryBank: ... @webmethod(route="/memory_banks/drop", method="DELETE") - def drop_memory_bank( + async def drop_memory_bank( self, bank_id: str, ) -> str: ... + # this will just block now until documents are inserted, but it should + # probably return a Job instance which can be polled for completion @webmethod(route="/memory_bank/insert") - def insert_documents( + async def insert_documents( self, bank_id: str, documents: List[MemoryBankDocument], ) -> None: ... @webmethod(route="/memory_bank/update") - def update_documents( + async def update_documents( self, bank_id: str, documents: List[MemoryBankDocument], ) -> None: ... @webmethod(route="/memory_bank/query") - def query_documents( + async def query_documents( self, bank_id: str, query: InterleavedTextMedia, @@ -123,14 +136,14 @@ class Memory(Protocol): ) -> QueryDocumentsResponse: ... @webmethod(route="/memory_bank/documents/get") - def get_documents( + async def get_documents( self, bank_id: str, document_ids: List[str], ) -> List[MemoryBankDocument]: ... @webmethod(route="/memory_bank/documents/delete") - def delete_documents( + async def delete_documents( self, bank_id: str, document_ids: List[str],