Query generators for RAG query (#54)

* Query generators for rag query

* use agent.inference_api instead of passing host/port again

* drop classes for functions

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-09-06 13:10:39 -07:00 committed by GitHub
parent 406c3b24d4
commit 5de6ed946e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 118 additions and 2 deletions

View file

@ -31,6 +31,7 @@ from llama_toolchain.tools.builtin import (
SingleMessageBuiltinTool,
)
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
@ -664,7 +665,9 @@ class ChatAgent(ShieldRunnerMixin):
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = " ".join(m.content for m in messages)
query = await generate_rag_query(
memory.query_generator_config, messages, inference_api=self.inference_api
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,