diff --git a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
index 42bffe981..afcc6afd1 100644
--- a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
+++ b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
@@ -21,65 +21,56 @@ from llama_toolchain.inference.api import *  # noqa: F403
 
 
 async def generate_rag_query(
-    generator_config: MemoryQueryGeneratorConfig,
+    config: MemoryQueryGeneratorConfig,
     messages: List[Message],
     **kwargs,
 ):
-    if generator_config.type == MemoryQueryGenerator.default.value:
-        generator = DefaultRAGQueryGenerator(generator_config, **kwargs)
-    elif generator_config.type == MemoryQueryGenerator.llm.value:
-        generator = LLMRAGQueryGenerator(generator_config, **kwargs)
+    """
+    Generates a query that will be used for
+    retrieving relevant information from the memory bank.
+    """
+    if config.type == MemoryQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, messages, **kwargs)
+    elif config.type == MemoryQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, messages, **kwargs)
     else:
-        raise NotImplementedError(
-            f"Unsupported memory query generator {generator_config.type}"
-        )
-
-    query = await generator.gen(messages)
+        raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     # cprint(f"Generated query >>>: {query}", color="green")
     return query
 
 
-class DefaultRAGQueryGenerator:
-    def __init__(self, config: DefaultMemoryQueryGeneratorConfig, **kwargs):
-        self.config = config
+async def default_rag_query_generator(
+    config: DefaultMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
 
-    async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
-        query = self.config.sep.join(
-            interleaved_text_media_as_str(m.content) for m in messages
+
+async def llm_rag_query_generator(
+    config: LLMMemoryQueryGeneratorConfig,
+    messages: List[Message],
+    **kwargs,
+):
+    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
+    inference_api = kwargs["inference_api"]
+
+    m_dict = {"messages": [m.model_dump() for m in messages]}
+
+    template = Template(config.template)
+    content = template.render(m_dict)
+
+    model = config.model
+    message = UserMessage(content=content)
+    response = inference_api.chat_completion(
+        ChatCompletionRequest(
+            model=model,
+            messages=[message],
+            stream=False,
         )
-        return query
+    )
+    async for chunk in response:
+        query = chunk.completion_message.content
 
-class LLMRAGQueryGenerator:
-    def __init__(self, config: LLMMemoryQueryGeneratorConfig, **kwargs):
-        self.config = config
-        assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
-        self.inference_api = kwargs["inference_api"]
-
-    async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
-        """
-        Generates a query that will be used for
-        retrieving relevant information from the memory bank.
-        """
-        # get template from user
-        # user template will assume data has the format of
-        # pydantic object representing List[Message]
-        m_dict = {"messages": [m.model_dump() for m in messages]}
-
-        template = Template(self.config.template)
-        content = template.render(m_dict)
-
-        model = self.config.model
-        message = UserMessage(content=content)
-        response = self.inference_api.chat_completion(
-            ChatCompletionRequest(
-                model=model,
-                messages=[message],
-                stream=False,
-            )
-        )
-
-        async for chunk in response:
-            query = chunk.completion_message.content
-
-        return query
+    return query
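
For readers of this change: the new llm_rag_query_generator renders the user-supplied Jinja2 template against {"messages": [m.model_dump() for m in messages]} before sending the rendered prompt to the inference API. A minimal sketch of that rendering step, using a hypothetical template string and a stand-in Message model (neither taken from this patch), might look like:

    from jinja2 import Template
    from pydantic import BaseModel


    # Stand-in for the real Message type from llama_toolchain (assumption for
    # illustration only); what matters is that model_dump() yields a dict.
    class Message(BaseModel):
        role: str
        content: str


    messages = [
        Message(role="user", content="How do I rotate my API keys?"),
        Message(role="assistant", content="Which provider are you using?"),
    ]

    # The template string below is illustrative; in the patch it comes from
    # config.template and is rendered against the messages dict exactly as here.
    template = Template(
        "Rewrite the conversation as a single search query:\n"
        "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
    )
    content = template.render({"messages": [m.model_dump() for m in messages]})
    print(content)

The rendered string is then wrapped in a UserMessage and sent via chat_completion with stream=False, as shown in the hunk above.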