mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
update the client tests to use Agent.with_memory
This commit is contained in:
parent
70b2a58bef
commit
8bf3f8ea56
1 changed files with 2 additions and 53 deletions
|
@ -17,8 +17,6 @@ from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
from llama_stack_client.types.custom_tool_def import Parameter
|
from llama_stack_client.types.custom_tool_def import Parameter
|
||||||
from llama_stack_client.types.memory_insert_params import Document
|
from llama_stack_client.types.memory_insert_params import Document
|
||||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.tool_def_param import CustomToolDefParam
|
|
||||||
from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomTool(CustomTool):
|
class TestCustomTool(CustomTool):
|
||||||
|
@ -253,58 +251,9 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
)
|
)
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
llama_stack_client.memory_banks.register(
|
|
||||||
memory_bank_id="test_bank",
|
|
||||||
params={
|
|
||||||
"memory_bank_type": "vector",
|
|
||||||
"embedding_model": "all-MiniLM-L6-v2",
|
|
||||||
"chunk_size_in_tokens": 512,
|
|
||||||
"overlap_size_in_tokens": 64,
|
|
||||||
},
|
|
||||||
provider_id="faiss",
|
|
||||||
)
|
|
||||||
|
|
||||||
# insert some documents
|
agent = Agent.with_memory(llama_stack_client, agent_config)
|
||||||
llama_stack_client.memory.insert(
|
[agent.add_document(document) for document in documents]
|
||||||
bank_id="test_bank",
|
|
||||||
documents=documents,
|
|
||||||
)
|
|
||||||
|
|
||||||
# create the required memory tool
|
|
||||||
llama_stack_client.toolgroups.register(
|
|
||||||
tool_group_id="memory_group",
|
|
||||||
tool_group=UserDefinedToolGroupDef(
|
|
||||||
type="user_defined",
|
|
||||||
tools=[
|
|
||||||
CustomToolDefParam(
|
|
||||||
type="custom",
|
|
||||||
name="memory-tool",
|
|
||||||
description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments",
|
|
||||||
parameters=[
|
|
||||||
Parameter(
|
|
||||||
name="input_messages",
|
|
||||||
description="Input messages for which to retrieve memory",
|
|
||||||
required=True,
|
|
||||||
parameter_type="list",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
metadata={
|
|
||||||
"config": {
|
|
||||||
"memory_bank_configs": [
|
|
||||||
{"bank_id": "test_bank", "type": "vector"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
provider_id="memory-runtime",
|
|
||||||
)
|
|
||||||
agent_config = {
|
|
||||||
**agent_config,
|
|
||||||
"preprocessing_tools": ["memory-tool"],
|
|
||||||
}
|
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
user_prompts = [
|
user_prompts = [
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue