diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cfe839dad..528246cdf 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -402,7 +402,6 @@ class ChatAgent(ShieldRunnerMixin): # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: query_args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(query_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -412,8 +411,8 @@ class ChatAgent(ShieldRunnerMixin): parse_status=ToolCallParseStatus.success, content=ToolCall( call_id="", - tool_name="memory", - arguments=serialized_args, + tool_name=MEMORY_QUERY_TOOL, + arguments={}, ), ), ) @@ -435,14 +434,14 @@ class ChatAgent(ShieldRunnerMixin): tool_calls=[ ToolCall( call_id="", - tool_name="memory", - arguments=serialized_args, + tool_name=MEMORY_QUERY_TOOL, + arguments={}, ) ], tool_responses=[ ToolResponse( call_id="", - tool_name="memory", + tool_name=MEMORY_QUERY_TOOL, content=result.content, ) ], @@ -456,7 +455,7 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("output", result.content) span.set_attribute("error_code", result.error_code) span.set_attribute("error_message", result.error_message) - span.set_attribute("tool_name", "memory") + span.set_attribute("tool_name", MEMORY_QUERY_TOOL) if result.error_code == 0: last_message = input_messages[-1] last_message.context = result.content diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index f27cb9dd4..a6ce744a6 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -56,7 +56,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ) -> List[ToolDef]: return [ ToolDef( - name="memory", + name="query_memory", description="Retrieve context from memory", parameters=[ ToolParameter( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 09cedced3..a4ad2278f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -101,7 +101,7 @@ def agent_config(llama_stack_client): "temperature": 1.0, "top_p": 0.9, }, - tools=[], + toolgroups=[], tool_choice="auto", tool_prompt_format="json", input_shields=available_shields, @@ -152,8 +152,8 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "tools": [ - "builtin::web_search", + "toolgroups": [ + "builtin::websearch", ], } agent = Agent(llama_stack_client, agent_config) @@ -181,7 +181,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, - "tools": [ + "toolgroups": [ "builtin::code_interpreter", ], } @@ -208,7 +208,7 @@ def test_code_execution(llama_stack_client): agent_config = AgentConfig( model="meta-llama/Llama-3.1-70B-Instruct", instructions="You are a helpful assistant", - tools=[ + toolgroups=[ "builtin::code_interpreter", ], tool_choice="required", @@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": ["builtin::web_search"], + "toolgroups": ["builtin::websearch"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } @@ -293,9 +293,14 @@ def test_rag_agent(llama_stack_client, agent_config): for i, url in enumerate(urls) ] - agent_config["tools"].append("builtin::memory") - agent = Agent(llama_stack_client, agent_config) memory_bank_id = "test-memory-bank" + agent_config["toolgroups"].append( + dict( + name="builtin::memory", + args={"memory_bank_id": memory_bank_id}, + ) + ) + agent = Agent(llama_stack_client, agent_config) llama_stack_client.memory_banks.register( memory_bank_id=memory_bank_id, params={ @@ -326,16 +331,8 @@ def test_rag_agent(llama_stack_client, agent_config): } ], session_id=session_id, - tools=[ - { - "name": "memory", - "args": { - "memory_bank_id": memory_bank_id, - }, - } - ], ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "Tool:memory" in logs_str + assert "Tool:query_memory" in logs_str