mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:12:26 +00:00
address feedback
This commit is contained in:
parent
ee542a7373
commit
16d1f66f55
9 changed files with 286 additions and 149 deletions
|
|
@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
return []
|
||||
|
||||
async def _retrieve_context(
|
||||
self, messages: List[Message], bank_ids: List[str]
|
||||
self, message: Message, bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
if len(messages) == 0:
|
||||
return None
|
||||
|
||||
message = messages[-1] # only use the last message as input to the query
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
message,
|
||||
|
|
@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
config = MemoryToolConfig()
|
||||
if tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
|
||||
if "memory_bank_id" in args:
|
||||
bank_ids = [args["memory_bank_id"]]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
context = await self._retrieve_context(
|
||||
args["input_messages"],
|
||||
[bank_config.bank_id for bank_config in config.memory_bank_configs],
|
||||
args["query"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue