Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 23:51:00 +00:00)

commit 66b658dcce (parent 62dd3b376c)

Get agents tests working

8 changed files with 352 additions and 269 deletions
llama_stack/providers/tests/agents/conftest.py (new file, 71 lines)

@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..conftest import get_provider_fixture_overrides

from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import AGENTS_FIXTURES


DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "inference": "meta_reference",
            "safety": "meta_reference",
            "memory": "meta_reference",
            "agents": "meta_reference",
        },
        id="meta_reference",
        marks=pytest.mark.meta_reference,
    ),
    pytest.param(
        {
            "inference": "ollama",
            "safety": "meta_reference",
            "memory": "meta_reference",
            "agents": "meta_reference",
        },
        id="ollama",
        marks=pytest.mark.ollama,
    ),
    pytest.param(
        {
            "inference": "together",
            "safety": "together",
            "memory": "meta_reference",
            "agents": "meta_reference",
        },
        id="together",
        marks=pytest.mark.together,
    ),
]


def pytest_configure(config):
    for mark in ["meta_reference", "ollama", "together"]:
        config.addinivalue_line(
            "markers",
            f"{mark}: marks tests as {mark} specific",
        )


def pytest_generate_tests(metafunc):
    if "agents_stack" in metafunc.fixturenames:
        available_fixtures = {
            "inference": INFERENCE_FIXTURES,
            "safety": SAFETY_FIXTURES,
            "memory": MEMORY_FIXTURES,
            "agents": AGENTS_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        metafunc.parametrize("agents_stack", combinations, indirect=True)
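Each default combination maps to a marker registered in pytest_configure, so a single combination can be selected from the command line. A usage sketch, assuming the corresponding providers are reachable in your environment:

    pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "meta_reference"
    pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "ollama"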
llama_stack/providers/tests/agents/fixtures.py (new file, 61 lines)

@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile

import pytest
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider

from llama_stack.providers.impls.meta_reference.agents import (
    MetaReferenceAgentsImplConfig,
)

from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

from ..conftest import ProviderFixture


@pytest.fixture(scope="session")
def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    return ProviderFixture(
        provider=Provider(
            provider_id="meta-reference",
            provider_type="meta-reference",
            config=MetaReferenceAgentsImplConfig(
                # TODO: make this an in-memory store
                persistence_store=SqliteKVStoreConfig(
                    db_path=sqlite_file.name,
                ),
            ).model_dump(),
        ),
    )


AGENTS_FIXTURES = ["meta_reference"]


@pytest_asyncio.fixture(scope="session")
async def agents_stack(inference_model, safety_model, request):
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["agents", "inference", "safety", "memory"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = [fixture.provider]
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    impls = await resolve_impls_for_test_v2(
        [Api.agents, Api.inference, Api.safety, Api.memory],
        providers,
        provider_data,
    )
    return impls[Api.agents], impls[Api.memory]
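For illustration, a minimal consumer of the agents_stack fixture could look like the sketch below (a hypothetical test, not part of this commit; it relies on the indirect parametrization wired up in conftest.py above):

import pytest


@pytest.mark.asyncio
async def test_agents_stack_resolves(agents_stack):
    # agents_stack yields (agents_impl, memory_impl), as returned by the fixture above
    agents_impl, memory_impl = agents_stack
    assert agents_impl is not None
    assert memory_impl is not None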
(deleted file)

@@ -1,34 +0,0 @@
providers:
  inference:
    - provider_id: together
      provider_type: remote::together
      config: {}
    - provider_id: tgi
      provider_type: remote::tgi
      config:
        url: http://127.0.0.1:7001
    # - provider_id: meta-reference
    #   provider_type: meta-reference
    #   config:
    #     model: Llama-Guard-3-1B
    # - provider_id: remote
    #   provider_type: remote
    #   config:
    #     host: localhost
    #     port: 7010
  safety:
    - provider_id: together
      provider_type: remote::together
      config: {}
  memory:
    - provider_id: faiss
      provider_type: meta-reference
      config: {}
  agents:
    - provider_id: meta-reference
      provider_type: meta-reference
      config:
        persistence_store:
          namespace: null
          type: sqlite
          db_path: ~/.llama/runtime/kvstore.db
@@ -7,47 +7,20 @@
 import os

 import pytest
-import pytest_asyncio

 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
 from llama_stack.providers.datatypes import *  # noqa: F403

-from dotenv import load_dotenv
-
 # How to run this test:
 #
-# 1. Ensure you have a conda environment with the right dependencies installed.
-#    This includes `pytest` and `pytest-asyncio`.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#   MODEL_ID=<your_model> \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/agents/test_agents.py \
-#   --tb=short --disable-warnings
-# ```
-
-load_dotenv()
+# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
+#   -m "meta_reference"


-@pytest_asyncio.fixture(scope="session")
-async def agents_settings():
-    impls = await resolve_impls_for_test(
-        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
-    )
-
+@pytest.fixture
+def common_params():
     return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
-        "common_params": {
-            "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
-            "instructions": "You are a helpful assistant.",
-        },
+        "instructions": "You are a helpful assistant.",
     }
@@ -83,230 +56,237 @@ def query_attachment_messages():
(updated version shown)

    ]


@pytest.mark.parametrize(
    "inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
)
@pytest.mark.parametrize(
    "safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
)
class TestAgents:
    @pytest.mark.asyncio
    async def test_create_agent_turn(
        self, agents_stack, sample_messages, common_params, inference_model
    ):
        agents_impl, _ = agents_stack

        # First, create an agent
        agent_config = AgentConfig(
            model=inference_model,
            instructions=common_params["instructions"],
            enable_session_persistence=True,
            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
            input_shields=[],
            output_shields=[],
            tools=[],
            max_infer_iters=5,
        )

        create_response = await agents_impl.create_agent(agent_config)
        agent_id = create_response.agent_id

        # Create a session
        session_create_response = await agents_impl.create_agent_session(
            agent_id, "Test Session"
        )
        session_id = session_create_response.session_id

        # Create and execute a turn
        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=sample_messages,
            stream=True,
        )

        turn_response = [
            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
        ]

        assert len(turn_response) > 0
        assert all(
            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
        )

        check_event_types(turn_response)
        check_turn_complete_event(turn_response, session_id, sample_messages)

    @pytest.mark.asyncio
    async def test_rag_agent_as_attachments(
        self,
        agents_stack,
        attachment_message,
        query_attachment_messages,
        inference_model,
        common_params,
    ):
        agents_impl, _ = agents_stack
        urls = [
            "memory_optimizations.rst",
            "chat.rst",
            "llama3.rst",
            "datasets.rst",
            "qat_finetune.rst",
            "lora_finetune.rst",
        ]

        attachments = [
            Attachment(
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
            )
            for i, url in enumerate(urls)
        ]

        agent_config = AgentConfig(
            model=inference_model,
            instructions=common_params["instructions"],
            enable_session_persistence=True,
            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
            input_shields=[],
            output_shields=[],
            tools=[
                MemoryToolDefinition(
                    memory_bank_configs=[],
                    query_generator_config={
                        "type": "default",
                        "sep": " ",
                    },
                    max_tokens_in_context=4096,
                    max_chunks=10,
                ),
            ],
            max_infer_iters=5,
        )

        create_response = await agents_impl.create_agent(agent_config)
        agent_id = create_response.agent_id

        # Create a session
        session_create_response = await agents_impl.create_agent_session(
            agent_id, "Test Session"
        )
        session_id = session_create_response.session_id

        # Create and execute a turn
        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=attachment_message,
            attachments=attachments,
            stream=True,
        )

        turn_response = [
            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
        ]

        assert len(turn_response) > 0

        # Create a second turn querying the agent
        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=query_attachment_messages,
            stream=True,
        )

        turn_response = [
            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
        ]

        assert len(turn_response) > 0

    @pytest.mark.asyncio
    async def test_create_agent_turn_with_brave_search(
        self, agents_stack, search_query_messages, common_params, inference_model
    ):
        agents_impl, _ = agents_stack

        if "BRAVE_SEARCH_API_KEY" not in os.environ:
            pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")

        # Create an agent with Brave search tool
        agent_config = AgentConfig(
            model=inference_model,
            instructions=common_params["instructions"],
            enable_session_persistence=True,
            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
            input_shields=[],
            output_shields=[],
            tools=[
                SearchToolDefinition(
                    type=AgentTool.brave_search.value,
                    api_key=os.environ["BRAVE_SEARCH_API_KEY"],
                    engine=SearchEngineType.brave,
                )
            ],
            tool_choice=ToolChoice.auto,
            max_infer_iters=5,
        )

        create_response = await agents_impl.create_agent(agent_config)
        agent_id = create_response.agent_id

        # Create a session
        session_create_response = await agents_impl.create_agent_session(
            agent_id, "Test Session with Brave Search"
        )
        session_id = session_create_response.session_id

        # Create and execute a turn
        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=search_query_messages,
            stream=True,
        )

        turn_response = [
            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
        ]

        assert len(turn_response) > 0
        assert all(
            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
        )

        # Check for expected event types
        check_event_types(turn_response)

        # Check for tool execution events
        tool_execution_events = [
            chunk
            for chunk in turn_response
            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
            and chunk.event.payload.step_details.step_type
            == StepType.tool_execution.value
        ]
        assert len(tool_execution_events) > 0, "No tool execution events found"

        # Check the tool execution details
        tool_execution = tool_execution_events[0].event.payload.step_details
        assert isinstance(tool_execution, ToolExecutionStep)
        assert len(tool_execution.tool_calls) > 0
        assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
        assert len(tool_execution.tool_responses) > 0

        # Check the final turn complete event
        check_turn_complete_event(turn_response, session_id, search_query_messages)


def check_event_types(turn_response):
    event_types = [chunk.event.payload.event_type for chunk in turn_response]
    assert AgentTurnResponseEventType.turn_start.value in event_types
    assert AgentTurnResponseEventType.step_start.value in event_types
    assert AgentTurnResponseEventType.step_complete.value in event_types
    assert AgentTurnResponseEventType.turn_complete.value in event_types


def check_turn_complete_event(turn_response, session_id, input_messages):
    final_event = turn_response[-1].event.payload
    assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
    assert isinstance(final_event.turn, Turn)
    assert final_event.turn.session_id == session_id
    assert final_event.turn.input_messages == input_messages
    assert isinstance(final_event.turn.output_message, CompletionMessage)
    assert len(final_event.turn.output_message.content) > 0
@@ -130,4 +130,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.inference.fixtures",
     "llama_stack.providers.tests.safety.fixtures",
     "llama_stack.providers.tests.memory.fixtures",
+    "llama_stack.providers.tests.agents.fixtures",
 ]
@@ -18,9 +18,8 @@ from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS

 # How to run this test:
 #
-# pytest llama_stack/providers/tests/inference/test_inference.py
+# pytest -v -s llama_stack/providers/tests/inference/test_inference.py
 #   -m "(fireworks or ollama) and llama_3b"
-#   -v -s --tb=short --disable-warnings
 #   --env FIREWORKS_API_KEY=<your_api_key>
@@ -20,7 +20,7 @@ from ..env import get_env_or_fail


 @pytest.fixture(scope="session")
-def meta_reference() -> ProviderFixture:
+def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="meta-reference",

@@ -31,7 +31,7 @@ def meta_reference() -> ProviderFixture:


 @pytest.fixture(scope="session")
-def pgvector() -> ProviderFixture:
+def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="pgvector",

@@ -48,7 +48,7 @@ def pgvector() -> ProviderFixture:


 @pytest.fixture(scope="session")
-def weaviate() -> ProviderFixture:
+def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="weaviate",

@@ -76,7 +76,7 @@ PROVIDER_PARAMS = [
 )
 async def memory_stack(request):
     fixture_name = request.param
-    fixture = request.getfixturevalue(fixture_name)
+    fixture = request.getfixturevalue(f"memory_{fixture_name}")

     impls = await resolve_impls_for_test_v2(
         [Api.memory],
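The memory_ prefix presumably keeps fixture names unambiguous now that the agents conftest composes fixtures from several APIs in one session. A sketch of the lookup pattern, using only the names that appear in the hunks above:

fixture_name = request.param  # e.g. "pgvector"
fixture = request.getfixturevalue(f"memory_{fixture_name}")  # resolves memory_pgvector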
@@ -11,6 +11,11 @@ from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.distribution.datatypes import *  # noqa: F403

+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/safety/test_safety.py
+#   -m "ollama"
+
 @pytest.mark.parametrize(
     "inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True