forked from phoenix-oss/llama-stack-mirror
agents to use tools api (#673)
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
parent
596afc6497
commit
a5c57cd381
116 changed files with 4959 additions and 2778 deletions
|
@ -9,24 +9,21 @@ from typing import Dict, List
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
|
||||
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
||||
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
|
||||
from llama_stack_client.types import ToolResponseMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.tool_param_definition_param import (
|
||||
ToolParamDefinitionParam,
|
||||
)
|
||||
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
from llama_stack_client.types.tool_def_param import Parameter
|
||||
|
||||
|
||||
class TestCustomTool(CustomTool):
|
||||
class TestClientTool(ClientTool):
|
||||
"""Tool to give boiling point of a liquid
|
||||
Returns the correct value for water in Celcius and Fahrenheit
|
||||
Returns the correct value for polyjuice in Celcius and Fahrenheit
|
||||
and returns -1 for other liquids
|
||||
|
||||
"""
|
||||
|
||||
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
||||
|
@ -54,15 +51,19 @@ class TestCustomTool(CustomTool):
|
|||
return "get_boiling_point"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
|
||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||
|
||||
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
|
||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||
return {
|
||||
"liquid_name": ToolParamDefinitionParam(
|
||||
param_type="string", description="The name of the liquid", required=True
|
||||
"liquid_name": Parameter(
|
||||
name="liquid_name",
|
||||
parameter_type="string",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinitionParam(
|
||||
param_type="boolean",
|
||||
"celcius": Parameter(
|
||||
name="celcius",
|
||||
parameter_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
|
@ -100,7 +101,7 @@ def agent_config(llama_stack_client):
|
|||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
tools=[],
|
||||
toolgroups=[],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
input_shields=available_shields,
|
||||
|
@ -148,18 +149,13 @@ def test_agent_simple(llama_stack_client, agent_config):
|
|||
assert "I can't" in logs_str
|
||||
|
||||
|
||||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
|
||||
}
|
||||
"toolgroups": [
|
||||
"builtin::websearch",
|
||||
],
|
||||
}
|
||||
print(f"Agent Config: {agent_config}")
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
|
@ -167,7 +163,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Search the web and tell me who the 44th president of the United States was. Please use tools",
|
||||
"content": "Search the web and tell me who the current CEO of Meta is.",
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
|
@ -178,18 +174,15 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
|||
|
||||
assert "tool_execution>" in logs_str
|
||||
assert "Tool:brave_search Response:" in logs_str
|
||||
assert "obama" in logs_str.lower()
|
||||
if len(agent_config["input_shields"]) > 0:
|
||||
assert "No Violation" in logs_str
|
||||
assert "mark zuckerberg" in logs_str.lower()
|
||||
assert "No Violation" in logs_str
|
||||
|
||||
|
||||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
{
|
||||
"type": "code_interpreter",
|
||||
}
|
||||
"toolgroups": [
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
}
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
|
@ -199,7 +192,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write code to answer the question: What is the 100th prime number?",
|
||||
"content": "Write code and execute it to find the answer for: What is the 100th prime number?",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
|
@ -207,50 +200,62 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
if "Tool:code_interpreter Response" not in logs_str:
|
||||
assert len(logs_str) > 0
|
||||
pytest.skip("code_interpreter not called by model")
|
||||
|
||||
assert "541" in logs_str
|
||||
assert "Tool:code_interpreter Response" in logs_str
|
||||
if "No such file or directory: 'bwrap'" in logs_str:
|
||||
assert "prime" in logs_str
|
||||
pytest.skip("`bwrap` is not available on this platform")
|
||||
else:
|
||||
assert "541" in logs_str
|
||||
|
||||
|
||||
def test_code_execution(llama_stack_client):
|
||||
agent_config = AgentConfig(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
toolgroups=[
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
tool_choice="required",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
codex_agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = codex_agent.create_session("test-session")
|
||||
inflation_doc = AgentDocument(
|
||||
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
|
||||
mime_type="text/csv",
|
||||
)
|
||||
|
||||
user_input = [
|
||||
{"prompt": "Here is a csv, can you describe it?", "documents": [inflation_doc]},
|
||||
{"prompt": "Plot average yearly inflation as a time series"},
|
||||
]
|
||||
|
||||
for input in user_input:
|
||||
response = codex_agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input["prompt"],
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
documents=input.get("documents", None),
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:code_interpreter" in logs_str
|
||||
|
||||
|
||||
def test_custom_tool(llama_stack_client, agent_config):
|
||||
client_tool = TestClientTool()
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tools": [
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
|
||||
},
|
||||
{
|
||||
"function_name": "get_boiling_point",
|
||||
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
"parameters": {
|
||||
"liquid_name": {
|
||||
"param_type": "str",
|
||||
"description": "The name of the liquid",
|
||||
"required": True,
|
||||
},
|
||||
"celcius": {
|
||||
"param_type": "boolean",
|
||||
"description": "Whether to return the boiling point in Celcius",
|
||||
"required": False,
|
||||
},
|
||||
},
|
||||
"type": "function_call",
|
||||
},
|
||||
],
|
||||
"toolgroups": ["builtin::websearch"],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
|
@ -267,3 +272,55 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
logs_str = "".join(logs)
|
||||
assert "-100" in logs_str
|
||||
assert "CustomTool" in logs_str
|
||||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
|
||||
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
metadata={},
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
memory_bank_id = "test-memory-bank"
|
||||
llama_stack_client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_id,
|
||||
params={
|
||||
"memory_bank_type": "vector",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size_in_tokens": 512,
|
||||
"overlap_size_in_tokens": 64,
|
||||
},
|
||||
)
|
||||
llama_stack_client.memory.insert(
|
||||
bank_id=memory_bank_id,
|
||||
documents=documents,
|
||||
)
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"toolgroups": [
|
||||
dict(
|
||||
name="builtin::memory",
|
||||
args={
|
||||
"memory_bank_ids": [memory_bank_id],
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
rag_agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = rag_agent.create_session("test-session")
|
||||
user_prompts = [
|
||||
"What are the top 5 topics that were explained? Only list succinct bullet points.",
|
||||
]
|
||||
for prompt in user_prompts:
|
||||
print(f"User> {prompt}")
|
||||
response = rag_agent.create_turn(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
session_id=session_id,
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:query_memory" in logs_str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue