mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
added tool calling test
This commit is contained in:
parent
ef4b74c935
commit
77a486f176
1 changed files with 101 additions and 0 deletions
|
@ -4,11 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
|
@ -26,6 +31,8 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
# --tb=short --disable-warnings
|
# --tb=short --disable-warnings
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def agents_settings():
|
async def agents_settings():
|
||||||
|
@ -50,6 +57,13 @@ def sample_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def search_query_messages():
|
||||||
|
return [
|
||||||
|
UserMessage(content="What are the latest developments in quantum computing?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
agents_impl = agents_settings["impl"]
|
agents_impl = agents_settings["impl"]
|
||||||
|
@ -107,3 +121,90 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
assert final_event.turn.input_messages == sample_messages
|
assert final_event.turn.input_messages == sample_messages
|
||||||
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
||||||
assert len(final_event.turn.output_message.content) > 0
|
assert len(final_event.turn.output_message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_turn_with_brave_search(
|
||||||
|
agents_settings, search_query_messages
|
||||||
|
):
|
||||||
|
agents_impl = agents_settings["impl"]
|
||||||
|
|
||||||
|
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||||
|
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
# Create an agent with Brave search tool
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model=agents_settings["common_params"]["model"],
|
||||||
|
instructions=agents_settings["common_params"]["instructions"],
|
||||||
|
enable_session_persistence=True,
|
||||||
|
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||||
|
input_shields=[],
|
||||||
|
output_shields=[],
|
||||||
|
tools=[
|
||||||
|
SearchToolDefinition(
|
||||||
|
type=AgentTool.brave_search.value,
|
||||||
|
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||||
|
engine=SearchEngineType.brave,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_choice=ToolChoice.auto,
|
||||||
|
max_infer_iters=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
create_response = await agents_impl.create_agent(agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_create_response = await agents_impl.create_agent_session(
|
||||||
|
agent_id, "Test Session with Brave Search"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
|
||||||
|
# Create and execute a turn
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=search_query_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(turn_response) > 0
|
||||||
|
assert all(
|
||||||
|
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for expected event types
|
||||||
|
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||||
|
assert AgentTurnResponseEventType.turn_start.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.step_start.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.step_complete.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.turn_complete.value in event_types
|
||||||
|
|
||||||
|
# Check for tool execution events
|
||||||
|
tool_execution_events = [
|
||||||
|
chunk
|
||||||
|
for chunk in turn_response
|
||||||
|
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||||
|
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
||||||
|
]
|
||||||
|
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||||
|
|
||||||
|
# Check the tool execution details
|
||||||
|
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||||
|
assert isinstance(tool_execution, ToolExecutionStep)
|
||||||
|
assert len(tool_execution.tool_calls) > 0
|
||||||
|
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||||
|
assert len(tool_execution.tool_responses) > 0
|
||||||
|
|
||||||
|
# Check the final turn complete event
|
||||||
|
final_event = turn_response[-1].event.payload
|
||||||
|
assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
|
||||||
|
assert isinstance(final_event.turn, Turn)
|
||||||
|
assert final_event.turn.session_id == session_id
|
||||||
|
assert final_event.turn.input_messages == search_query_messages
|
||||||
|
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
||||||
|
assert len(final_event.turn.output_message.content) > 0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue