mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Get agents tests working
This commit is contained in:
parent
62dd3b376c
commit
66b658dcce
8 changed files with 352 additions and 269 deletions
71
llama_stack/providers/tests/agents/conftest.py
Normal file
71
llama_stack/providers/tests/agents/conftest.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "meta_reference",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="meta_reference",
|
||||
marks=pytest.mark.meta_reference,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"safety": "meta_reference",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="ollama",
|
||||
marks=pytest.mark.ollama,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "together",
|
||||
"safety": "together",
|
||||
"memory": "meta_reference",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for mark in ["meta_reference", "ollama", "together"]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{mark}: marks tests as {mark} specific",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "agents_stack" in metafunc.fixturenames:
|
||||
available_fixtures = {
|
||||
"inference": INFERENCE_FIXTURES,
|
||||
"safety": SAFETY_FIXTURES,
|
||||
"memory": MEMORY_FIXTURES,
|
||||
"agents": AGENTS_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
or DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
metafunc.parametrize("agents_stack", combinations, indirect=True)
|
61
llama_stack/providers/tests/agents/fixtures.py
Normal file
61
llama_stack/providers/tests/agents/fixtures.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=MetaReferenceAgentsImplConfig(
|
||||
# TODO: make this an in-memory store
|
||||
persistence_store=SqliteKVStoreConfig(
|
||||
db_path=sqlite_file.name,
|
||||
),
|
||||
).model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
AGENTS_FIXTURES = ["meta_reference"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(inference_model, safety_model, request):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["agents", "inference", "safety", "memory"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = [fixture.provider]
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
providers,
|
||||
provider_data,
|
||||
)
|
||||
return impls[Api.agents], impls[Api.memory]
|
|
@ -1,34 +0,0 @@
|
|||
providers:
|
||||
inference:
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
- provider_id: tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://127.0.0.1:7001
|
||||
# - provider_id: meta-reference
|
||||
# provider_type: meta-reference
|
||||
# config:
|
||||
# model: Llama-Guard-3-1B
|
||||
# - provider_id: remote
|
||||
# provider_type: remote
|
||||
# config:
|
||||
# host: localhost
|
||||
# port: 7010
|
||||
safety:
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
memory:
|
||||
- provider_id: faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
|
@ -7,47 +7,20 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda environment with the right dependencies installed.
|
||||
# This includes `pytest` and `pytest-asyncio`.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# MODEL_ID=<your_model> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/agents/test_agents.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
load_dotenv()
|
||||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_settings():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def common_params():
|
||||
return {
|
||||
"impl": impls[Api.agents],
|
||||
"memory_impl": impls[Api.memory],
|
||||
"common_params": {
|
||||
"model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
},
|
||||
"instructions": "You are a helpful assistant.",
|
||||
}
|
||||
|
||||
|
||||
|
@ -83,230 +56,237 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||
agents_impl = agents_settings["impl"]
|
||||
@pytest.mark.parametrize(
|
||||
"inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
|
||||
)
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn(
|
||||
self, agents_stack, sample_messages, common_params, inference_model
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
|
||||
# First, create an agent
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
# Check for expected event types
|
||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||
assert AgentTurnResponseEventType.turn_start.value in event_types
|
||||
assert AgentTurnResponseEventType.step_start.value in event_types
|
||||
assert AgentTurnResponseEventType.step_complete.value in event_types
|
||||
assert AgentTurnResponseEventType.turn_complete.value in event_types
|
||||
|
||||
# Check the final turn complete event
|
||||
final_event = turn_response[-1].event.payload
|
||||
assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
|
||||
assert isinstance(final_event.turn, Turn)
|
||||
assert final_event.turn.session_id == session_id
|
||||
assert final_event.turn.input_messages == sample_messages
|
||||
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
||||
assert len(final_event.turn.output_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
agents_settings, attachment_message, query_attachment_messages
|
||||
):
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
# First, create an agent
|
||||
agent_config = AgentConfig(
|
||||
model=inference_model,
|
||||
instructions=common_params["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agents_impl = agents_settings["impl"]
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
check_event_types(turn_response)
|
||||
check_turn_complete_event(turn_response, session_id, sample_messages)
|
||||
|
||||
assert len(turn_response) > 0
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
self,
|
||||
agents_stack,
|
||||
attachment_message,
|
||||
query_attachment_messages,
|
||||
inference_model,
|
||||
common_params,
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
# Create a second turn querying the agent
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=query_attachment_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
agents_settings, search_query_messages
|
||||
):
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
# Create an agent with Brave search tool
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value,
|
||||
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.brave,
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
],
|
||||
tool_choice=ToolChoice.auto,
|
||||
max_infer_iters=5,
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
agent_config = AgentConfig(
|
||||
model=inference_model,
|
||||
instructions=common_params["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session with Brave Search"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
# Check for expected event types
|
||||
assert len(turn_response) > 0
|
||||
|
||||
# Create a second turn querying the agent
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=query_attachment_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
self, agents_stack, search_query_messages, common_params, inference_model
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
# Create an agent with Brave search tool
|
||||
agent_config = AgentConfig(
|
||||
model=inference_model,
|
||||
instructions=common_params["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
SearchToolDefinition(
|
||||
type=AgentTool.brave_search.value,
|
||||
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
|
||||
engine=SearchEngineType.brave,
|
||||
)
|
||||
],
|
||||
tool_choice=ToolChoice.auto,
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session with Brave Search"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=search_query_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||
)
|
||||
|
||||
# Check for expected event types
|
||||
check_event_types(turn_response)
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type
|
||||
== StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
# Check the final turn complete event
|
||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
||||
|
||||
|
||||
def check_event_types(turn_response):
|
||||
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||
assert AgentTurnResponseEventType.turn_start.value in event_types
|
||||
assert AgentTurnResponseEventType.step_start.value in event_types
|
||||
assert AgentTurnResponseEventType.step_complete.value in event_types
|
||||
assert AgentTurnResponseEventType.turn_complete.value in event_types
|
||||
|
||||
# Check for tool execution events
|
||||
tool_execution_events = [
|
||||
chunk
|
||||
for chunk in turn_response
|
||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
||||
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
||||
]
|
||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
||||
|
||||
# Check the tool execution details
|
||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
||||
assert isinstance(tool_execution, ToolExecutionStep)
|
||||
assert len(tool_execution.tool_calls) > 0
|
||||
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
|
||||
assert len(tool_execution.tool_responses) > 0
|
||||
|
||||
# Check the final turn complete event
|
||||
def check_turn_complete_event(turn_response, session_id, input_messages):
|
||||
final_event = turn_response[-1].event.payload
|
||||
assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
|
||||
assert isinstance(final_event.turn, Turn)
|
||||
assert final_event.turn.session_id == session_id
|
||||
assert final_event.turn.input_messages == search_query_messages
|
||||
assert final_event.turn.input_messages == input_messages
|
||||
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
||||
assert len(final_event.turn.output_message.content) > 0
|
||||
|
|
|
@ -130,4 +130,5 @@ pytest_plugins = [
|
|||
"llama_stack.providers.tests.inference.fixtures",
|
||||
"llama_stack.providers.tests.safety.fixtures",
|
||||
"llama_stack.providers.tests.memory.fixtures",
|
||||
"llama_stack.providers.tests.agents.fixtures",
|
||||
]
|
||||
|
|
|
@ -18,9 +18,8 @@ from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS
|
|||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/inference/test_inference.py
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_inference.py
|
||||
# -m "(fireworks or ollama) and llama_3b"
|
||||
# -v -s --tb=short --disable-warnings
|
||||
# --env FIREWORKS_API_KEY=<your_api_key>
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from ..env import get_env_or_fail
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def meta_reference() -> ProviderFixture:
|
||||
def memory_meta_reference() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="meta-reference",
|
||||
|
@ -31,7 +31,7 @@ def meta_reference() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pgvector() -> ProviderFixture:
|
||||
def memory_pgvector() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="pgvector",
|
||||
|
@ -48,7 +48,7 @@ def pgvector() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate() -> ProviderFixture:
|
||||
def memory_weaviate() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="weaviate",
|
||||
|
@ -76,7 +76,7 @@ PROVIDER_PARAMS = [
|
|||
)
|
||||
async def memory_stack(request):
|
||||
fixture_name = request.param
|
||||
fixture = request.getfixturevalue(fixture_name)
|
||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.memory],
|
||||
|
|
|
@ -11,6 +11,11 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest -v -s llama_stack/providers/tests/safety/test_safety.py
|
||||
# -m "ollama"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue