From 439f52b0670c77c4fd69e06e6d3b8b59d79bad45 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 26 Dec 2024 16:02:41 -0800 Subject: [PATCH] register toolgroup as part of test --- tests/client-sdk/agents/test_agents.py | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7f8b5b26b..36674631b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -17,6 +17,8 @@ from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.custom_tool_def import Parameter from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage +from llama_stack_client.types.tool_def_param import CustomToolDefParam +from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef class TestCustomTool(CustomTool): @@ -268,6 +270,36 @@ def test_rag_agent(llama_stack_client, agent_config): documents=documents, ) + # create the required memory tool + llama_stack_client.toolgroups.register( + tool_group_id="memory_group", + tool_group=UserDefinedToolGroupDef( + type="user_defined", + tools=[ + CustomToolDefParam( + type="custom", + name="memory-tool", + description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments", + parameters=[ + Parameter( + name="input_messages", + description="Input messages for which to retrieve memory", + required=True, + parameter_type="list", + ), + ], + metadata={ + "config": { + "memory_bank_configs": [ + {"bank_id": "test_bank", "type": "vector"} + ] + } + }, + ) + ], + ), + provider_id="memory-runtime", + ) agent_config = { **agent_config, "preprocessing_tools": ["memory-tool"],