mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-02 20:40:36 +00:00
agents to use tools api (#673)
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
This commit is contained in:
parent
596afc6497
commit
a5c57cd381
116 changed files with 4959 additions and 2778 deletions
|
@ -5,22 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentTurnResponseEventType,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
Attachment,
|
||||
MemoryToolDefinition,
|
||||
SearchEngineType,
|
||||
SearchToolDefinition,
|
||||
Document,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolChoice,
|
||||
|
@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
|
|||
#
|
||||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
@ -51,7 +45,7 @@ def common_params(inference_model):
|
|||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
toolgroups=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
|
@ -88,73 +82,6 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
async def create_agent_turn_with_search_tool(
|
||||
agents_stack: Dict[str, object],
|
||||
search_query_messages: List[object],
|
||||
common_params: Dict[str, str],
|
||||
search_tool_definition: SearchToolDefinition,
|
||||
) -> None:
|
||||
"""
|
||||
Create an agent turn with a search tool.
|
||||
|
||||
Args:
|
||||
agents_stack (Dict[str, object]): The agents stack.
|
||||
search_query_messages (List[object]): The search query messages.
|
||||
common_params (Dict[str, str]): The common parameters.
|
||||
search_tool_definition (SearchToolDefinition): The search tool definition.
|
||||
"""
|
||||
|
||||
# Create an agent with the search tool
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [search_tool_definition],
|
||||
}
|
||||
)
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_stack.impls[Api.agents], agent_config
|
||||
)
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk
|
||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||
**turn_request
|
||||
)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(
|
||||
|
@ -227,7 +154,7 @@ class TestAgents:
|
|||
check_turn_complete_event(turn_response, session_id, sample_messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
async def test_rag_agent(
|
||||
self,
|
||||
agents_stack,
|
||||
attachment_message,
|
||||
|
@ -243,29 +170,17 @@ class TestAgents:
|
|||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
documents = [
|
||||
Document(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"tools": [
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
"toolgroups": ["builtin::memory"],
|
||||
"tool_choice": ToolChoice.auto,
|
||||
}
|
||||
)
|
||||
|
@ -275,7 +190,7 @@ class TestAgents:
|
|||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
documents=documents,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [
|
||||
|
@ -298,22 +213,6 @@ class TestAgents:
|
|||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
search_tool_definition = SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value,
|
||||
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.brave,
|
||||
)
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
|
@ -321,14 +220,57 @@ class TestAgents:
|
|||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
search_tool_definition = SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value, # place holder only
|
||||
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.tavily,
|
||||
# Create an agent with the toolgroup
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"toolgroups": ["builtin::web_search"],
|
||||
}
|
||||
)
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack, search_query_messages, common_params, search_tool_definition
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_stack.impls[Api.agents], agent_config
|
||||
)
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk
|
||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
||||
**turn_request
|
||||
)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type
|
||||
== StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
||||
assert actual_tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
def check_event_types(turn_response):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue