agents to use tools api (#673)

# What does this PR do?

PR #639 introduced the notion of Tools API and ability to invoke tools
through API just as any resource. This PR changes the Agents to start
using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig
2) Agent gets the corresponding tool definitions for the specified tools
and pass along to the model
3) Attachements are now named as Documents and their behavior is mostly
unchanged from user perspective
4) You can specify args that can be injected to a tool call through
Agent config. This is especially useful in case of memory tool, where
you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent
inject these as well into the tool call.
6) All tests have been migrated to use new tools API and fixtures
including client SDK tests
7) Telemetry just works with tools API because of our trace protocol
decorator


## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
Dinesh Yeduguru 2025-01-08 19:01:00 -08:00 committed by GitHub
parent 596afc6497
commit a5c57cd381
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
116 changed files with 4959 additions and 2778 deletions

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import unittest
from llama_models.llama3.api.datatypes import (
Attachment,
BuiltinTool,
CompletionMessage,
StopReason,
ToolCall,
)
from ..tools.builtin import CodeInterpreterTool
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
async def test_matplotlib(self):
tool = CodeInterpreterTool()
code = """
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 1])
y = np.array([0, 10])
plt.plot(x, y)
plt.title('x = 1')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.axvline(x=1, color='r')
plt.show()
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertIsInstance(output, Attachment)
self.assertEqual(output.mime_type, "image/png")
async def test_path_unlink(self):
tool = CodeInterpreterTool()
code = """
import os
from pathlib import Path
import tempfile
dpath = Path(os.environ["MPLCONFIGDIR"])
with open(dpath / "test", "w") as f:
f.write("hello")
Path(dpath / "test").unlink()
print("_OK_")
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertTrue("_OK_" in output)
if __name__ == "__main__":
unittest.main()

View file

@ -4,21 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
StepType,
)
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
from llama_stack.apis.safety import RunShieldResponse
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceInferenceConfig,
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
ToolPromptFormat,
)
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.inline.agents.meta_reference.agents import (
MetaReferenceAgentsImpl,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI:
@ -48,10 +64,10 @@ class MockInferenceAPI:
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if stream:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="start",
@ -65,19 +81,7 @@ class MockInferenceAPI:
delta="AI is a fascinating field...",
)
)
# yield ChatCompletionResponseStreamChunk(
# event=ChatCompletionResponseEvent(
# event_type="progress",
# delta=ToolCallDelta(
# content=ToolCall(
# call_id="123",
# tool_name=BuiltinTool.brave_search.value,
# arguments={"query": "AI history"},
# ),
# parse_status="success",
# ),
# )
# )
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="complete",
@ -85,12 +89,17 @@ class MockInferenceAPI:
stop_reason="end_of_turn",
)
)
if stream:
return stream_response()
else:
yield ChatCompletionResponse(
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant", content="Mock response", stop_reason="end_of_turn"
role="assistant",
content="Mock response",
stop_reason="end_of_turn",
),
logprobs=[0.1, 0.2, 0.3] if logprobs else None,
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
)
@ -165,6 +174,98 @@ class MockMemoryAPI:
self.documents[bank_id].pop(doc_id, None)
class MockToolGroupsAPI:
async def register_tool_group(
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return ToolGroup(
identifier=toolgroup_id,
provider_resource_id=toolgroup_id,
)
async def list_tool_groups(self) -> List[ToolGroup]:
return []
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
if tool_group_id == MEMORY_TOOLGROUP:
return [
Tool(
identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL,
toolgroup_id=MEMORY_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::memory",
parameters=[],
)
]
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
return [
Tool(
identifier="code_interpreter",
provider_resource_id="code_interpreter",
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::code_interpreter",
parameters=[],
)
]
return []
async def get_tool(self, tool_name: str) -> Tool:
return Tool(
identifier=tool_name,
provider_resource_id=tool_name,
toolgroup_id="mock_group",
tool_host=ToolHost.client,
description="Mock tool",
provider_id="mock_provider",
parameters=[],
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
pass
class MockToolRuntimeAPI:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return []
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})
class MockMemoryBanksAPI:
async def list_memory_banks(self) -> List[MemoryBank]:
return []
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return None
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
return VectorMemoryBank(
identifier=memory_bank_id,
provider_resource_id=provider_memory_bank_id or memory_bank_id,
embedding_model="mock_model",
chunk_size_in_tokens=512,
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
pass
@pytest.fixture
def mock_inference_api():
return MockInferenceAPI()
@ -181,64 +282,107 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
def mock_tool_groups_api():
return MockToolGroupsAPI()
@pytest.fixture
def mock_tool_runtime_api():
return MockToolRuntimeAPI()
@pytest.fixture
def mock_memory_banks_api():
return MockMemoryBanksAPI()
@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_memory_api,
mock_memory_banks_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
impl = MetaReferenceAgentsImpl(
config=MetaReferenceInferenceConfig(),
config=MetaReferenceAgentsImplConfig(
persistence_store=SqliteKVStoreConfig(
db_name=sqlite_file.name,
),
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,
memory_banks_api=mock_memory_banks_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
await impl.initialize()
return impl
@pytest.fixture
async def get_chat_agent(get_agents_impl):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
sampling_params=SamplingParams(),
tools=[
# SearchToolDefinition(
# name="brave_search",
# api_key="test_key",
# ),
],
toolgroups=[],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[],
output_shields=[],
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
agent = AGENT_INSTANCES_BY_ID[response.agent_id]
return agent
return await impl.get_agent(response.agent_id)
MEMORY_TOOLGROUP = "builtin::memory"
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
@pytest.fixture
async def get_chat_agent_with_tools(get_agents_impl, request):
impl = await get_agents_impl
toolgroups = request.param
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
session = chat_agent.create_session("Test Session")
assert session.session_name == "Test Session"
assert session.turns == []
assert session.session_id in chat_agent.sessions
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(chat_agent):
session = chat_agent.create_session("Test Session")
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Hello")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
print(responses)
assert len(responses) > 0
assert len(responses) == 4 # TurnStart, StepStart, StepComplete, TurnComplete
assert (
len(responses) == 7
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
assert responses[0].event.payload.turn_id is not None
@pytest.mark.asyncio
async def test_run_multiple_shields_wrapper(chat_agent):
async def test_run_multiple_shields_wrapper(get_chat_agent):
chat_agent = await get_chat_agent
messages = [UserMessage(content="Test message")]
shields = ["test_shield"]
@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):
assert len(responses) == 2 # StepStart, StepComplete
assert responses[0].event.payload.step_type.value == "shield_call"
assert not responses[1].event.payload.step_details.response.is_violation
assert not responses[1].event.payload.step_details.violation
@pytest.mark.asyncio
@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
async def test_chat_agent_complex_turn(chat_agent):
# Setup
session = chat_agent.create_session("Test Session")
async def test_chat_agent_complex_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
stream=True,
)
# Execute the turn
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
# Assertions
assert len(responses) > 0
# Check for the presence of different step types
step_types = [
response.event.payload.step_type
for response in responses
if hasattr(response.event.payload, "step_type")
]
assert "shield_call" in step_types, "Shield call step is missing"
assert "inference" in step_types, "Inference step is missing"
assert "tool_execution" in step_types, "Tool execution step is missing"
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
# Check for the presence of start and complete events
event_types = [
response.event.payload.event_type
for response in responses
if hasattr(response.event.payload, "event_type")
]
assert "start" in event_types, "Start event is missing"
assert "complete" in event_types, "Complete event is missing"
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
# Check for the presence of tool call
tool_calls = [
response.event.payload.tool_call
for response in responses
if hasattr(response.event.payload, "tool_call")
]
assert any(
tool_call
for tool_call in tool_calls
if tool_call and tool_call.content.get("name") == "memory"
), "Memory tool call is missing"
# Check for the final turn complete event
assert any(
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
for response in responses
), "Turn complete event is missing"
turn_complete_payload = next(
response.event.payload
for response in responses
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
)
turn = turn_complete_payload.turn
assert turn.input_messages == request.messages, "Input messages do not match"
# Verify the turn was added to the session
assert len(session.turns) == 1, "Turn was not added to the session"
assert (
session.turns[0].input_messages == request.messages
), "Input messages do not match"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"toolgroups, expected_memory, expected_code_interpreter",
[
([], False, False), # no tools
([MEMORY_TOOLGROUP], True, False), # memory only
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs()
if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs
if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs
if expected_memory and expected_code_interpreter:
# override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"memory_banks": ["test_memory_bank"]},
)
]
)
assert MEMORY_QUERY_TOOL in new_tool_defs
assert BuiltinTool.code_interpreter not in new_tool_defs