diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index e3f417918..68ec980e6 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -116,10 +116,47 @@ MemoryBankConfig = Annotated[
 ]
 
 
-@json_schema_type
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default=DefaultMemoryQueryGeneratorConfig()
+    )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
 
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index ed3145b1e..4d38e0032 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
+from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 
 
@@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
diff --git a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
new file mode 100644
index 000000000..afcc6afd1
--- /dev/null
+++ b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from jinja2 import Template
+from llama_models.llama3.api import *  # noqa: F403
+
+
+from llama_toolchain.agentic_system.api import (
+    DefaultMemoryQueryGeneratorConfig,
+    LLMMemoryQueryGeneratorConfig,
+    MemoryQueryGenerator,
+    MemoryQueryGeneratorConfig,
+)
+from termcolor import cprint  # noqa: F401
+from llama_toolchain.inference.api import *  # noqa: F403
+
+
+async def generate_rag_query(
+    config: MemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    """
+    Generates a query that will be used for
+    retrieving relevant information from the memory bank.
+    """
+    if config.type == MemoryQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, messages, **kwargs)
+    elif config.type == MemoryQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, messages, **kwargs)
+    else:
+        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
+    # cprint(f"Generated query >>>: {query}", color="green")
+    return query
+
+
+async def default_rag_query_generator(
+    config: DefaultMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+
+
+async def llm_rag_query_generator(
+    config: LLMMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
+    inference_api = kwargs["inference_api"]
+
+    m_dict = {"messages": [m.model_dump() for m in messages]}
+
+    template = Template(config.template)
+    content = template.render(m_dict)
+
+    model = config.model
+    message = UserMessage(content=content)
+    response = inference_api.chat_completion(
+        ChatCompletionRequest(
+            model=model,
+            messages=[message],
+            stream=False,
+        )
+    )
+
+    async for chunk in response:
+        query = chunk.completion_message.content
+
+    return query
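Usage sketch (not part of the patch): one way client code might select the new LLM-based query generator on the memory tool. The model identifier and the Jinja2 template below are illustrative assumptions, not values this diff defines; the only API surface used is what the patch adds to `llama_toolchain.agentic_system.api`.

```python
# Sketch only: configure the memory tool to rewrite the retrieval query with an
# LLM instead of the default message concatenation. Model name and template
# text are assumptions for illustration.
from llama_toolchain.agentic_system.api import (
    LLMMemoryQueryGeneratorConfig,
    MemoryToolDefinition,
)

memory_tool = MemoryToolDefinition(
    memory_bank_configs=[],  # no pre-registered banks in this sketch
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",  # assumed model identifier
        template=(
            "Rewrite the conversation below as a short search query "
            "for the memory bank.\n"
            "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
        ),
    ),
    max_tokens_in_context=4096,
    max_chunks=10,
)
```

Leaving `query_generator_config` at its default keeps the prior behavior: the turn's messages are joined with `sep`, matching the `" ".join(...)` line this patch removes from `agent_instance.py`.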
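Since `llm_rag_query_generator` renders the template against `{"messages": [m.model_dump() for m in messages]}`, a template author can iterate over `messages` and read dict fields such as `role` and `content`. A standalone sketch of just that rendering step (the sample messages and template text are made up for illustration):

```python
# Standalone sketch of the rendering step used by llm_rag_query_generator:
# messages are dumped to dicts and exposed to Jinja2 under the "messages" key.
from jinja2 import Template

m_dict = {
    "messages": [
        {"role": "user", "content": "Here is our Q3 planning doc."},
        {"role": "user", "content": "What are the open risks?"},
    ]
}

template = Template(
    "Write one search query that captures the user's information need:\n"
    "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
)

print(template.render(m_dict))
# Write one search query that captures the user's information need:
# user: Here is our Q3 planning doc.
# user: What are the open risks?
```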