register toolgroup as part of test

This commit is contained in:
Dinesh Yeduguru 2024-12-26 16:02:41 -08:00
parent f408fd3aca
commit 439f52b067

View file

@ -17,6 +17,8 @@ from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import CustomToolDefParam
from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef
class TestCustomTool(CustomTool):
@ -268,6 +270,36 @@ def test_rag_agent(llama_stack_client, agent_config):
documents=documents,
)
# create the required memory tool
llama_stack_client.toolgroups.register(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
type="user_defined",
tools=[
CustomToolDefParam(
type="custom",
name="memory-tool",
description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments",
parameters=[
Parameter(
name="input_messages",
description="Input messages for which to retrieve memory",
required=True,
parameter_type="list",
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
)
agent_config = {
**agent_config,
"preprocessing_tools": ["memory-tool"],