agents to use tools api (#673)

# What does this PR do?

PR #639 introduced the notion of Tools API and ability to invoke tools
through API just as any resource. This PR changes the Agents to start
using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig
2) Agent gets the corresponding tool definitions for the specified tools
and pass along to the model
3) Attachements are now named as Documents and their behavior is mostly
unchanged from user perspective
4) You can specify args that can be injected to a tool call through
Agent config. This is especially useful in case of memory tool, where
you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent
inject these as well into the tool call.
6) All tests have been migrated to use new tools API and fixtures
including client SDK tests
7) Telemetry just works with tools API because of our trace protocol
decorator


## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
Dinesh Yeduguru 2025-01-08 19:01:00 -08:00 committed by GitHub
parent 596afc6497
commit a5c57cd381
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
116 changed files with 4959 additions and 2778 deletions

View file

@ -9,24 +9,21 @@ from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import Parameter
class TestCustomTool(CustomTool):
class TestClientTool(ClientTool):
"""Tool to give boiling point of a liquid
Returns the correct value for water in Celcius and Fahrenheit
Returns the correct value for polyjuice in Celcius and Fahrenheit
and returns -1 for other liquids
"""
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
@ -54,15 +51,19 @@ class TestCustomTool(CustomTool):
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
def get_params_definition(self) -> Dict[str, Parameter]:
return {
"liquid_name": ToolParamDefinitionParam(
param_type="string", description="The name of the liquid", required=True
"liquid_name": Parameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinitionParam(
param_type="boolean",
"celcius": Parameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
@ -100,7 +101,7 @@ def agent_config(llama_stack_client):
"temperature": 1.0,
"top_p": 0.9,
},
tools=[],
toolgroups=[],
tool_choice="auto",
tool_prompt_format="json",
input_shields=available_shields,
@ -148,18 +149,13 @@ def test_agent_simple(llama_stack_client, agent_config):
assert "I can't" in logs_str
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
}
"toolgroups": [
"builtin::websearch",
],
}
print(f"Agent Config: {agent_config}")
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -167,7 +163,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "Search the web and tell me who the 44th president of the United States was. Please use tools",
"content": "Search the web and tell me who the current CEO of Meta is.",
}
],
session_id=session_id,
@ -178,18 +174,15 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
assert "tool_execution>" in logs_str
assert "Tool:brave_search Response:" in logs_str
assert "obama" in logs_str.lower()
if len(agent_config["input_shields"]) > 0:
assert "No Violation" in logs_str
assert "mark zuckerberg" in logs_str.lower()
assert "No Violation" in logs_str
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "code_interpreter",
}
"toolgroups": [
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client, agent_config)
@ -199,7 +192,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "Write code to answer the question: What is the 100th prime number?",
"content": "Write code and execute it to find the answer for: What is the 100th prime number?",
},
],
session_id=session_id,
@ -207,50 +200,62 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
if "Tool:code_interpreter Response" not in logs_str:
assert len(logs_str) > 0
pytest.skip("code_interpreter not called by model")
assert "541" in logs_str
assert "Tool:code_interpreter Response" in logs_str
if "No such file or directory: 'bwrap'" in logs_str:
assert "prime" in logs_str
pytest.skip("`bwrap` is not available on this platform")
else:
assert "541" in logs_str
def test_code_execution(llama_stack_client):
agent_config = AgentConfig(
model="meta-llama/Llama-3.1-8B-Instruct",
instructions="You are a helpful assistant",
toolgroups=[
"builtin::code_interpreter",
],
tool_choice="required",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
codex_agent = Agent(llama_stack_client, agent_config)
session_id = codex_agent.create_session("test-session")
inflation_doc = AgentDocument(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="text/csv",
)
user_input = [
{"prompt": "Here is a csv, can you describe it?", "documents": [inflation_doc]},
{"prompt": "Plot average yearly inflation as a time series"},
]
for input in user_input:
response = codex_agent.create_turn(
messages=[
{
"role": "user",
"content": input["prompt"],
}
],
session_id=session_id,
documents=input.get("documents", None),
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:code_interpreter" in logs_str
def test_custom_tool(llama_stack_client, agent_config):
client_tool = TestClientTool()
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
},
"type": "function_call",
},
],
"toolgroups": ["builtin::websearch"],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -267,3 +272,55 @@ def test_custom_tool(llama_stack_client, agent_config):
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str
def test_rag_agent(llama_stack_client, agent_config):
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
memory_bank_id = "test-memory-bank"
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
)
llama_stack_client.memory.insert(
bank_id=memory_bank_id,
documents=documents,
)
agent_config = {
**agent_config,
"toolgroups": [
dict(
name="builtin::memory",
args={
"memory_bank_ids": [memory_bank_id],
},
)
],
}
rag_agent = Agent(llama_stack_client, agent_config)
session_id = rag_agent.create_session("test-session")
user_prompts = [
"What are the top 5 topics that were explained? Only list succinct bullet points.",
]
for prompt in user_prompts:
print(f"User> {prompt}")
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_memory" in logs_str