mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:12:26 +00:00
address feedback
This commit is contained in:
parent
ee542a7373
commit
16d1f66f55
9 changed files with 286 additions and 149 deletions
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool
|
||||
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import ToolResponseMessage
|
||||
|
|
@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config):
|
|||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tool_names": [
|
||||
"tools": [
|
||||
"brave_search",
|
||||
],
|
||||
}
|
||||
print(f"Agent Config: {agent_config}")
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
|
|
@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
|||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tool_names": [
|
||||
"tools": [
|
||||
"code_interpreter",
|
||||
],
|
||||
}
|
||||
|
|
@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tool_names": ["brave_search"],
|
||||
"tools": ["brave_search"],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
|
@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agent = Agent.with_memory(llama_stack_client, agent_config)
|
||||
[agent.add_document(document) for document in documents]
|
||||
memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
llama_stack_client.memory.insert(
|
||||
bank_id=memory_bank_id,
|
||||
documents=documents,
|
||||
)
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
user_prompts = [
|
||||
|
|
@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
tools=[
|
||||
{
|
||||
"name": "memory",
|
||||
"args": {
|
||||
"memory_bank_id": memory_bank_id,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:memory-tool" in logs_str
|
||||
assert "Tool:memory" in logs_str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue