address feedback

This commit is contained in:
Dinesh Yeduguru 2025-01-02 18:42:20 -08:00
parent ee542a7373
commit 16d1f66f55
9 changed files with 286 additions and 149 deletions

View file

@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return []
async def _retrieve_context(
self, messages: List[Message], bank_ids: List[str]
self, message: Message, bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
if len(messages) == 0:
return None
message = messages[-1] # only use the last message as input to the query
query = await generate_rag_query(
self.config.query_generator_config,
message,
@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
config = MemoryToolConfig()
if tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_id" in args:
bank_ids = [args["memory_bank_id"]]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
context = await self._retrieve_context(
args["input_messages"],
[bank_config.bank_id for bank_config in config.memory_bank_configs],
args["query"],
bank_ids,
)
if context is None:
context = []