diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index cc3650f5d..68ec980e6 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -133,8 +133,6 @@ class LLMMemoryQueryGeneratorConfig(BaseModel):
     type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
     model: str
     template: str
-    host: str = "localhost"
-    port: int = 5000
 
 
 class CustomMemoryQueryGeneratorConfig(BaseModel):
@@ -157,7 +155,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
     query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig
+        default=DefaultMemoryQueryGeneratorConfig()
     )
     max_tokens_in_context: int = 4096
     max_chunks: int = 10
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 9c2ebf68e..4d38e0032 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -31,7 +31,7 @@ from llama_toolchain.tools.builtin import (
     SingleMessageBuiltinTool,
 )
 
-from .context_retriever import generate_rag_query
+from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
 
 
@@ -665,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
             # (i.e., no prior turns uploaded an Attachment)
             return None, []
 
-        query = await generate_rag_query(memory.query_generator_config, messages)
+        query = await generate_rag_query(
+            memory.query_generator_config, messages, inference_api=self.inference_api
+        )
         tasks = [
             self.memory_api.query_documents(
                 bank_id=bank_id,
diff --git a/llama_toolchain/agentic_system/meta_reference/context_retriever.py b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
similarity index 77%
rename from llama_toolchain/agentic_system/meta_reference/context_retriever.py
rename to llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
index 1514d8b65..42bffe981 100644
--- a/llama_toolchain/agentic_system/meta_reference/context_retriever.py
+++ b/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py
@@ -10,38 +10,37 @@
 from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403
 
-from termcolor import cprint
-
 from llama_toolchain.agentic_system.api import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
+from termcolor import cprint  # noqa: F401
 
 from llama_toolchain.inference.api import *  # noqa: F403
-from llama_toolchain.inference.client import InferenceClient
 
 
 async def generate_rag_query(
     generator_config: MemoryQueryGeneratorConfig,
     messages: List[Message],
+    **kwargs,
 ):
     if generator_config.type == MemoryQueryGenerator.default.value:
-        generator = DefaultRAGQueryGenerator(generator_config)
+        generator = DefaultRAGQueryGenerator(generator_config, **kwargs)
     elif generator_config.type == MemoryQueryGenerator.llm.value:
-        generator = LLMRAGQueryGenerator(generator_config)
+        generator = LLMRAGQueryGenerator(generator_config, **kwargs)
     else:
         raise NotImplementedError(
             f"Unsupported memory query generator {generator_config.type}"
         )
     query = await generator.gen(messages)
-    cprint(f"Generated query >>>: {query}", color="green")
+    # cprint(f"Generated query >>>: {query}", color="green")
     return query
 
 
 class DefaultRAGQueryGenerator:
-    def __init__(self, config: DefaultMemoryQueryGeneratorConfig):
+    def __init__(self, config: DefaultMemoryQueryGeneratorConfig, **kwargs):
         self.config = config
 
     async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
@@ -52,11 +51,12 @@ class DefaultRAGQueryGenerator:
 
 
 class LLMRAGQueryGenerator:
-    def __init__(self, config: LLMMemoryQueryGeneratorConfig):
+    def __init__(self, config: LLMMemoryQueryGeneratorConfig, **kwargs):
         self.config = config
+        assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
+        self.inference_api = kwargs["inference_api"]
 
     async def gen(self, messages: List[Message]) -> InterleavedTextMedia:
-        # params will have
         """
         Generates a query that will be used for
         retrieving relevant information from the memory bank.
@@ -69,15 +69,9 @@ class LLMRAGQueryGenerator:
         template = Template(self.config.template)
         content = template.render(m_dict)
 
-        cprint(f"Rendered Template >>>: {content}", color="yellow")
-        # TODO: How to manage these config params better ?
-        host = self.config.host
-        port = self.config.port
-        client = InferenceClient(f"http://{host}:{port}")
-
         model = self.config.model
         message = UserMessage(content=content)
-        response = client.chat_completion(
+        response = self.inference_api.chat_completion(
            ChatCompletionRequest(
                model=model,
                messages=[message],
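For reference, a minimal sketch (not part of the patch) of how the relocated generate_rag_query is expected to be driven after this change: the inference dependency is injected via the inference_api keyword argument (inside ChatAgent it is self.inference_api) instead of the generator constructing its own InferenceClient from host/port. The function name, model identifier, and template below are illustrative placeholders, not values from this diff.

# Sketch, assuming an already-constructed Inference implementation `inference_api`
# and a list of `Message` objects; model and template values are placeholders.
from llama_toolchain.agentic_system.api import LLMMemoryQueryGeneratorConfig
from llama_toolchain.agentic_system.meta_reference.rag.context_retriever import (
    generate_rag_query,
)


async def build_memory_query(inference_api, messages):
    config = LLMMemoryQueryGeneratorConfig(
        model="Meta-Llama3.1-8B-Instruct",  # placeholder model identifier
        template="{{ messages[-1].content }}",  # placeholder Jinja2 template
    )
    # host/port are gone from the config; the inference API is passed through
    # generate_rag_query's **kwargs into LLMRAGQueryGenerator.
    return await generate_rag_query(config, messages, inference_api=inference_api)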