diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index e3f417918..cc3650f5d 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -116,10 +116,49 @@ MemoryBankConfig = Annotated[
 ]
 
 
-@json_schema_type
+class MemoryQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+class DefaultMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.default.value] = (
+        MemoryQueryGenerator.default.value
+    )
+    sep: str = " "
+
+
+class LLMMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
+    model: str
+    template: str
+    host: str = "localhost"
+    port: int = 5000
+
+
+class CustomMemoryQueryGeneratorConfig(BaseModel):
+    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
+
+
+MemoryQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultMemoryQueryGeneratorConfig,
+        LLMMemoryQueryGeneratorConfig,
+        CustomMemoryQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
 class MemoryToolDefinition(ToolDefinitionCommon):
     type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+    # This config defines how a query is generated from the messages
+    # for memory bank retrieval.
+    query_generator_config: MemoryQueryGeneratorConfig = Field(
+        default_factory=DefaultMemoryQueryGeneratorConfig
+    )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
 
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index ed3145b1e..9c2ebf68e 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
+from .context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 
 
@@ -664,7 +665,7 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = " ".join(m.content for m in messages)
+        query = await generate_rag_query(memory.query_generator_config, messages)
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
diff --git a/llama_toolchain/agentic_system/meta_reference/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/context_retriever.py
new file mode 100644
index 000000000..1514d8b65
--- /dev/null
+++ b/llama_toolchain/agentic_system/meta_reference/context_retriever.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from jinja2 import Template
+from llama_models.llama3.api import *  # noqa: F403
+
+
+from termcolor import cprint
+
+from llama_toolchain.agentic_system.api import (
+    DefaultMemoryQueryGeneratorConfig,
+    LLMMemoryQueryGeneratorConfig,
+    MemoryQueryGenerator,
+    MemoryQueryGeneratorConfig,
+)
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.client import InferenceClient
+
+
+async def generate_rag_query(
+    generator_config: MemoryQueryGeneratorConfig,
+    messages: List[Message],
+):
+    if generator_config.type == MemoryQueryGenerator.default.value:
+        generator = DefaultRAGQueryGenerator(generator_config)
+    elif generator_config.type == MemoryQueryGenerator.llm.value:
+        generator = LLMRAGQueryGenerator(generator_config)
+    else:
+        raise NotImplementedError(
+            f"Unsupported memory query generator {generator_config.type}"
+        )
+
+    query = await generator.gen(messages)
+    cprint(f"Generated query >>>: {query}", color="green")
+    return query
+
+
+class DefaultRAGQueryGenerator:
+    def __init__(self, config: DefaultMemoryQueryGeneratorConfig):
+        self.config = config
+
+    async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
+        query = self.config.sep.join(
+            interleaved_text_media_as_str(m.content) for m in messages
+        )
+        return query
+
+
+class LLMRAGQueryGenerator:
+    def __init__(self, config: LLMMemoryQueryGeneratorConfig):
+        self.config = config
+
+    async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
+        """
+        Generates a query that will be used for
+        retrieving relevant information from the memory bank.
+        """
+        # Render the user-provided Jinja2 template; it sees the messages as a
+        # list of dicts (pydantic dumps of List[Message]).
+        m_dict = {"messages": [m.model_dump() for m in messages]}
+
+        template = Template(self.config.template)
+        content = template.render(m_dict)
+
+        cprint(f"Rendered Template >>>: {content}", color="yellow")
+        # TODO: How to manage these config params better?
+        host = self.config.host
+        port = self.config.port
+        client = InferenceClient(f"http://{host}:{port}")
+
+        model = self.config.model
+        message = UserMessage(content=content)
+        response = client.chat_completion(
+            ChatCompletionRequest(
+                model=model,
+                messages=[message],
+                stream=False,
+            )
+        )
+
+        # The client yields an async stream even with stream=False; the final
+        # chunk's completion_message carries the generated query.
+        async for chunk in response:
+            query = chunk.completion_message.content
+
+        return query
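
Reviewer note, not part of the diff: a minimal usage sketch of the new query generators. The model name, the template, and the assumption that an inference server is reachable at localhost:5000 are illustrative, not defaults introduced by this change.

# sketch.py -- hypothetical example, assuming an inference server on localhost:5000
import asyncio

from llama_models.llama3.api import UserMessage

from llama_toolchain.agentic_system.api import (
    DefaultMemoryQueryGeneratorConfig,
    LLMMemoryQueryGeneratorConfig,
)
from llama_toolchain.agentic_system.meta_reference.context_retriever import (
    generate_rag_query,
)

messages = [UserMessage(content="How do I configure a memory bank for my agent?")]

# Default generator: joins message contents with `sep`, roughly what the old
# inline " ".join(...) did.
default_query = asyncio.run(
    generate_rag_query(DefaultMemoryQueryGeneratorConfig(sep=" "), messages)
)

# LLM generator: renders the Jinja2 template over the messages, then asks the
# configured model to produce the retrieval query.
llm_config = LLMMemoryQueryGeneratorConfig(
    model="Meta-Llama3.1-8B-Instruct",  # illustrative model name
    template="Rewrite the last user turn as a search query: {{ messages[-1].content }}",
)
llm_query = asyncio.run(generate_rag_query(llm_config, messages))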