diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 528246cdf..f9ffb2ae0 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -401,7 +401,9 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
         if session_info.memory_bank_id:
-            query_args["memory_bank_id"] = session_info.memory_bank_id
+            if "memory_bank_ids" not in query_args:
+                query_args["memory_bank_ids"] = []
+            query_args["memory_bank_ids"].append(session_info.memory_bank_id)
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
                 payload=AgentTurnResponseStepProgressPayload(
diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
index 1fb1d0992..803981f07 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py
@@ -8,8 +8,10 @@
 from typing import List
 
 from jinja2 import Template
+from pydantic import BaseModel
 
-from llama_stack.apis.inference import Message, UserMessage
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -24,7 +26,7 @@ from .config import (
 
 async def generate_rag_query(
     config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     """
@@ -42,21 +44,26 @@ async def generate_rag_query(
 
 async def default_rag_query_generator(
     config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)
 
 
 async def llm_rag_query_generator(
     config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
     **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {"messages": [message.model_dump() for message in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }
 
     template = Template(config.template)
     content = template.render(m_dict)
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index a6ce744a6..f46b37510 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -69,7 +69,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]
 
     async def _retrieve_context(
-        self, input_messages: List[str], bank_ids: List[str]
+        self, input_messages: List[InterleavedContent], bank_ids: List[str]
     ) -> Optional[List[InterleavedContent]]:
         if not bank_ids:
             return None
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index a4ad2278f..01ffe2025 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 
 def test_code_execution(llama_stack_client):
     agent_config = AgentConfig(
-        model="meta-llama/Llama-3.1-70B-Instruct",
+        model="meta-llama/Llama-3.1-8B-Instruct",
         instructions="You are a helpful assistant",
         toolgroups=[
             "builtin::code_interpreter",
@@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config):
     agent_config["toolgroups"].append(
         dict(
             name="builtin::memory",
-            args={"memory_bank_id": memory_bank_id},
+            args={"memory_bank_ids": [memory_bank_id]},
         )
     )
     agent = Agent(llama_stack_client, agent_config)
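
Taken together, the agent_instance.py and test_agents.py hunks switch the memory tool from a single, overwriting "memory_bank_id" argument to an additive "memory_bank_ids" list: the session-scoped bank is appended to whatever banks the caller configured instead of replacing them. A minimal sketch of that merge behavior, assuming the same dict-of-args shape used by ChatAgent above (the merge_session_bank helper is hypothetical and does not exist in llama_stack):

    # Hypothetical helper illustrating the additive merge; not part of the diff.
    from typing import Any, Dict, List, Optional

    def merge_session_bank(query_args: Dict[str, Any], session_bank_id: Optional[str]) -> List[str]:
        # Start from whatever banks the caller passed via toolgroup args.
        bank_ids: List[str] = list(query_args.get("memory_bank_ids", []))
        # Append the session's own bank rather than overwriting, mirroring the hunk above.
        if session_bank_id:
            bank_ids.append(session_bank_id)
        return bank_ids

    # Matches the test's new config shape: args={"memory_bank_ids": [memory_bank_id]}
    assert merge_session_bank({"memory_bank_ids": ["docs_bank"]}, "session_bank") == [
        "docs_bank",
        "session_bank",
    ]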
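
The context_retriever.py hunks relax messages from List[Message] to List[InterleavedContent], so items may now be plain strings as well as pydantic content models; the isinstance(message, BaseModel) guard keeps template rendering working for both. A small standalone illustration of that dual handling (TextItem is a stand-in model for this sketch, not a real llama_stack type):

    from pydantic import BaseModel

    # Stand-in for a pydantic content model; illustrative only.
    class TextItem(BaseModel):
        type: str = "text"
        text: str

    messages = ["find the quarterly numbers", TextItem(text="and the 2023 summary")]
    # Same dual handling as the new m_dict construction above: strings pass
    # through unchanged, pydantic models are serialized to plain dicts.
    rendered = [m.model_dump() if isinstance(m, BaseModel) else m for m in messages]
    # rendered == ["find the quarterly numbers", {"type": "text", "text": "and the 2023 summary"}]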