update the client tests to use Agent.with_memory

This commit is contained in:
Dinesh Yeduguru 2024-12-30 14:46:57 -08:00
parent 70b2a58bef
commit 8bf3f8ea56

View file

@ -17,8 +17,6 @@ from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import CustomToolDefParam
from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef
class TestCustomTool(CustomTool):
@ -253,58 +251,9 @@ def test_rag_agent(llama_stack_client, agent_config):
)
for i, url in enumerate(urls)
]
llama_stack_client.memory_banks.register(
memory_bank_id="test_bank",
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
provider_id="faiss",
)
# insert some documents
llama_stack_client.memory.insert(
bank_id="test_bank",
documents=documents,
)
# create the required memory tool
llama_stack_client.toolgroups.register(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
type="user_defined",
tools=[
CustomToolDefParam(
type="custom",
name="memory-tool",
description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments",
parameters=[
Parameter(
name="input_messages",
description="Input messages for which to retrieve memory",
required=True,
parameter_type="list",
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
)
agent_config = {
**agent_config,
"preprocessing_tools": ["memory-tool"],
}
agent = Agent(llama_stack_client, agent_config)
agent = Agent.with_memory(llama_stack_client, agent_config)
[agent.add_document(document) for document in documents]
session_id = agent.create_session(f"test-session-{uuid4()}")
user_prompts = [