From 66b658dcce00f534e4ee9e9cba586bfc5847bf55 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 4 Nov 2024 10:14:55 -0800
Subject: [PATCH] Get agents tests working

---
 .../providers/tests/agents/conftest.py        |  71 +++
 .../providers/tests/agents/fixtures.py        |  76 +++
 .../tests/agents/provider_config_example.yaml |  34 --
 .../providers/tests/agents/test_agents.py     | 438 +++++++++---------
 llama_stack/providers/tests/conftest.py       |   1 +
 .../tests/inference/test_inference.py         |   3 +-
 .../providers/tests/memory/fixtures.py        |   8 +-
 .../providers/tests/safety/test_safety.py     |   5 +
 8 files changed, 367 insertions(+), 269 deletions(-)
 create mode 100644 llama_stack/providers/tests/agents/conftest.py
 create mode 100644 llama_stack/providers/tests/agents/fixtures.py
 delete mode 100644 llama_stack/providers/tests/agents/provider_config_example.yaml

diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py
new file mode 100644
index 000000000..b43bc9327
--- /dev/null
+++ b/llama_stack/providers/tests/agents/conftest.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
+from ..memory.fixtures import MEMORY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES
+from .fixtures import AGENTS_FIXTURES
+
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "safety": "meta_reference",
+            "memory": "meta_reference",
+            "agents": "meta_reference",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "safety": "meta_reference",
+            "memory": "meta_reference",
+            "agents": "meta_reference",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "safety": "together",
+            "memory": "meta_reference",
+            "agents": "meta_reference",
+        },
+        id="together",
+        marks=pytest.mark.together,
+    ),
+]
+
+
+def pytest_configure(config):
+    for mark in ["meta_reference", "ollama", "together"]:
+        config.addinivalue_line(
+            "markers",
+            f"{mark}: marks tests as {mark} specific",
+        )
+
+
+def pytest_generate_tests(metafunc):
+    if "agents_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "safety": SAFETY_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+            "agents": AGENTS_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("agents_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
new file mode 100644
index 000000000..5597f47e9
--- /dev/null
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
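+#
+# Shared fixtures for the agents tests. Each supporting API contributes its
+# own provider fixture (e.g. inference_ollama, safety_meta_reference), and
+# agents_stack below assembles one stack from the combination selected in
+# conftest.py. A typical invocation -- assuming the --providers override
+# option defined in the top-level tests conftest -- looks like:
+#
+#   pytest -v -s llama_stack/providers/tests/agents/test_agents.py \
+#     -m ollama --providers safety=meta_reference,memory=meta_reference
+#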
+
+import tempfile
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.impls.meta_reference.agents import (
+    MetaReferenceAgentsImplConfig,
+)
+
+from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def agents_meta_reference(inference_model, safety_model) -> ProviderFixture:
+    # inference_model and safety_model are unused here; declaring them ties
+    # this session-scoped fixture to the parametrized model values.
+    # delete=False: the agents impl reopens the DB file by path later.
+    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    return ProviderFixture(
+        provider=Provider(
+            provider_id="meta-reference",
+            provider_type="meta-reference",
+            config=MetaReferenceAgentsImplConfig(
+                # TODO: make this an in-memory store
+                persistence_store=SqliteKVStoreConfig(
+                    db_path=sqlite_file.name,
+                ),
+            ).model_dump(),
+        ),
+    )
+
+
+AGENTS_FIXTURES = ["meta_reference"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def agents_stack(inference_model, safety_model, request):
+    fixture_dict = request.param
+
+    # Pick up the provider fixture chosen for each API (e.g. inference_ollama)
+    # and collect any per-provider data (API keys etc.) that it carries.
+    providers = {}
+    provider_data = {}
+    for key in ["agents", "inference", "safety", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = [fixture.provider]
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    impls = await resolve_impls_for_test_v2(
+        [Api.agents, Api.inference, Api.safety, Api.memory],
+        providers,
+        provider_data,
+    )
+    return impls[Api.agents], impls[Api.memory]
diff --git a/llama_stack/providers/tests/agents/provider_config_example.yaml b/llama_stack/providers/tests/agents/provider_config_example.yaml
deleted file mode 100644
index 58f05e29a..000000000
--- a/llama_stack/providers/tests/agents/provider_config_example.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-providers:
-  inference:
-    - provider_id: together
-      provider_type: remote::together
-      config: {}
-    - provider_id: tgi
-      provider_type: remote::tgi
-      config:
-        url: http://127.0.0.1:7001
-#   - provider_id: meta-reference
-#     provider_type: meta-reference
-#     config:
-#       model: Llama-Guard-3-1B
-#   - provider_id: remote
-#     provider_type: remote
-#     config:
-#       host: localhost
-#       port: 7010
-  safety:
-    - provider_id: together
-      provider_type: remote::together
-      config: {}
-  memory:
-    - provider_id: faiss
-      provider_type: meta-reference
-      config: {}
-  agents:
-    - provider_id: meta-reference
-      provider_type: meta-reference
-      config:
-        persistence_store:
-          namespace: null
-          type: sqlite
-          db_path: ~/.llama/runtime/kvstore.db
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index c09db3d20..21035ba44 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -7,47 +7,20 @@
 import os
 
 import pytest
-import pytest_asyncio
 
 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
 from llama_stack.providers.datatypes import *  # noqa: F403
 
-from dotenv import load_dotenv
-
 
 # How to run this test:
 #
-# 1. Ensure you have a conda environment with the right dependencies installed.
-#    This includes `pytest` and `pytest-asyncio`.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID= \
-# MODEL_ID= \
-# PROVIDER_CONFIG=provider_config.yaml \
-# pytest -s llama_stack/providers/tests/agents/test_agents.py \
-# --tb=short --disable-warnings
-# ```
-
-load_dotenv()
+# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
+# -m "meta_reference"
 
 
-@pytest_asyncio.fixture(scope="session")
-async def agents_settings():
-    impls = await resolve_impls_for_test(
-        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
-    )
-
+@pytest.fixture
+def common_params():
     return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
-        "common_params": {
-            "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct",
-            "instructions": "You are a helpful assistant.",
-        },
+        "instructions": "You are a helpful assistant.",
     }
 
@@ -83,230 +56,237 @@ def query_attachment_messages():
     ]
 
 
-@pytest.mark.asyncio
-async def test_create_agent_turn(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=sample_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    assert len(turn_response) > 0
-    assert all(
-        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
-    )
-
-    # Check for expected event types
-    event_types = [chunk.event.payload.event_type for chunk in turn_response]
-    assert AgentTurnResponseEventType.turn_start.value in event_types
-    assert AgentTurnResponseEventType.step_start.value in event_types
-    assert AgentTurnResponseEventType.step_complete.value in event_types
-    assert AgentTurnResponseEventType.turn_complete.value in event_types
-
-    # Check the final turn complete event
-    final_event = turn_response[-1].event.payload
-    assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
-    assert isinstance(final_event.turn, Turn)
-    assert final_event.turn.session_id == session_id
-    assert final_event.turn.input_messages == sample_messages
-    assert isinstance(final_event.turn.output_message, CompletionMessage)
-    assert len(final_event.turn.output_message.content) > 0
-
-
-@pytest.mark.asyncio
-async def test_rag_agent_as_attachments(
-    agents_settings, attachment_message, query_attachment_messages
-):
-    urls = [
-        "memory_optimizations.rst",
-        "chat.rst",
-        "llama3.rst",
-        "datasets.rst",
-        "qat_finetune.rst",
-        "lora_finetune.rst",
-    ]
-
-    attachments = [
-        Attachment(
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
-        )
-        for i, url in enumerate(urls)
-    ]
-
-    agents_impl = agents_settings["impl"]
-
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[
-            MemoryToolDefinition(
-                memory_bank_configs=[],
-                query_generator_config={
-                    "type": "default",
-                    "sep": " ",
-                },
-                max_tokens_in_context=4096,
-                max_chunks=10,
-            ),
-        ],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=attachment_message,
-        attachments=attachments,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    assert len(turn_response) > 0
-
-    # Create a second turn querying the agent
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=query_attachment_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    assert len(turn_response) > 0
-
-
-@pytest.mark.asyncio
-async def test_create_agent_turn_with_brave_search(
-    agents_settings, search_query_messages
-):
-    agents_impl = agents_settings["impl"]
-
-    if "BRAVE_SEARCH_API_KEY" not in os.environ:
-        pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
-
-    # Create an agent with Brave search tool
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[
-            SearchToolDefinition(
-                type=AgentTool.brave_search.value,
-                api_key=os.environ["BRAVE_SEARCH_API_KEY"],
-                engine=SearchEngineType.brave,
-            )
-        ],
-        tool_choice=ToolChoice.auto,
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session with Brave Search"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=search_query_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    assert len(turn_response) > 0
-    assert all(
-        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
-    )
-
-    # Check for expected event types
-    event_types = [chunk.event.payload.event_type for chunk in turn_response]
-    assert AgentTurnResponseEventType.turn_start.value in event_types
-    assert AgentTurnResponseEventType.step_start.value in event_types
-    assert AgentTurnResponseEventType.step_complete.value in event_types
-    assert AgentTurnResponseEventType.turn_complete.value in event_types
-
-    # Check for tool execution events
-    tool_execution_events = [
-        chunk
-        for chunk in turn_response
-        if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
-        and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
-    ]
-    assert len(tool_execution_events) > 0, "No tool execution events found"
-
-    # Check the tool execution details
-    tool_execution = tool_execution_events[0].event.payload.step_details
-    assert isinstance(tool_execution, ToolExecutionStep)
-    assert len(tool_execution.tool_calls) > 0
-    assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
-    assert len(tool_execution.tool_responses) > 0
-
-    # Check the final turn complete event
-    final_event = turn_response[-1].event.payload
-    assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
-    assert isinstance(final_event.turn, Turn)
-    assert final_event.turn.session_id == session_id
-    assert final_event.turn.input_messages == search_query_messages
-    assert isinstance(final_event.turn.output_message, CompletionMessage)
-    assert len(final_event.turn.output_message.content) > 0
+@pytest.mark.parametrize(
+    "inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True
+)
+@pytest.mark.parametrize(
+    "safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True
+)
+class TestAgents:
+    @pytest.mark.asyncio
+    async def test_create_agent_turn(
+        self, agents_stack, sample_messages, common_params, inference_model
+    ):
+        agents_impl, _ = agents_stack
+
+        # First, create an agent
+        agent_config = AgentConfig(
+            model=inference_model,
+            instructions=common_params["instructions"],
+            enable_session_persistence=True,
+            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+            input_shields=[],
+            output_shields=[],
+            tools=[],
+            max_infer_iters=5,
+        )
+
+        create_response = await agents_impl.create_agent(agent_config)
+        agent_id = create_response.agent_id
+
+        # Create a session
+        session_create_response = await agents_impl.create_agent_session(
+            agent_id, "Test Session"
+        )
+        session_id = session_create_response.session_id
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=sample_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        assert len(turn_response) > 0
+        assert all(
+            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+        )
+
+        check_event_types(turn_response)
+        check_turn_complete_event(turn_response, session_id, sample_messages)
+
+    @pytest.mark.asyncio
+    async def test_rag_agent_as_attachments(
+        self,
+        agents_stack,
+        attachment_message,
+        query_attachment_messages,
+        inference_model,
+        common_params,
+    ):
+        agents_impl, _ = agents_stack
+        urls = [
+            "memory_optimizations.rst",
+            "chat.rst",
+            "llama3.rst",
+            "datasets.rst",
+            "qat_finetune.rst",
+            "lora_finetune.rst",
+        ]
+
+        attachments = [
+            Attachment(
+                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+                mime_type="text/plain",
+            )
+            for url in urls
+        ]
+
+        agent_config = AgentConfig(
+            model=inference_model,
+            instructions=common_params["instructions"],
+            enable_session_persistence=True,
+            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+            input_shields=[],
+            output_shields=[],
+            tools=[
+                MemoryToolDefinition(
+                    memory_bank_configs=[],
+                    query_generator_config={
+                        "type": "default",
+                        "sep": " ",
+                    },
+                    max_tokens_in_context=4096,
+                    max_chunks=10,
+                ),
+            ],
+            max_infer_iters=5,
+        )
+
+        create_response = await agents_impl.create_agent(agent_config)
+        agent_id = create_response.agent_id
+
+        # Create a session
+        session_create_response = await agents_impl.create_agent_session(
+            agent_id, "Test Session"
+        )
+        session_id = session_create_response.session_id
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=attachment_message,
+            attachments=attachments,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        assert len(turn_response) > 0
+
+        # Create a second turn querying the agent
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=query_attachment_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        assert len(turn_response) > 0
+
+    @pytest.mark.asyncio
+    async def test_create_agent_turn_with_brave_search(
+        self, agents_stack, search_query_messages, common_params, inference_model
+    ):
+        agents_impl, _ = agents_stack
+
+        if "BRAVE_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
+
+        # Create an agent with Brave search tool
+        agent_config = AgentConfig(
+            model=inference_model,
+            instructions=common_params["instructions"],
+            enable_session_persistence=True,
+            sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+            input_shields=[],
+            output_shields=[],
+            tools=[
+                SearchToolDefinition(
+                    type=AgentTool.brave_search.value,
+                    api_key=os.environ["BRAVE_SEARCH_API_KEY"],
+                    engine=SearchEngineType.brave,
+                )
+            ],
+            tool_choice=ToolChoice.auto,
+            max_infer_iters=5,
+        )
+
+        create_response = await agents_impl.create_agent(agent_config)
+        agent_id = create_response.agent_id
+
+        # Create a session
+        session_create_response = await agents_impl.create_agent_session(
+            agent_id, "Test Session with Brave Search"
+        )
+        session_id = session_create_response.session_id
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=search_query_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        assert len(turn_response) > 0
+        assert all(
+            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+        )
+
+        # Check for expected event types
+        check_event_types(turn_response)
+
+        # Check for tool execution events
+        tool_execution_events = [
+            chunk
+            for chunk in turn_response
+            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
+            and chunk.event.payload.step_details.step_type
+            == StepType.tool_execution.value
+        ]
+        assert len(tool_execution_events) > 0, "No tool execution events found"
+
+        # Check the tool execution details
+        tool_execution = tool_execution_events[0].event.payload.step_details
+        assert isinstance(tool_execution, ToolExecutionStep)
+        assert len(tool_execution.tool_calls) > 0
+        assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
+        assert len(tool_execution.tool_responses) > 0
+
+        # Check the final turn complete event
+        check_turn_complete_event(turn_response, session_id, search_query_messages)
+
+
+def check_event_types(turn_response):
+    event_types = [chunk.event.payload.event_type for chunk in turn_response]
+    assert AgentTurnResponseEventType.turn_start.value in event_types
+    assert AgentTurnResponseEventType.step_start.value in event_types
+    assert AgentTurnResponseEventType.step_complete.value in event_types
+    assert AgentTurnResponseEventType.turn_complete.value in event_types
+
+
+def check_turn_complete_event(turn_response, session_id, input_messages):
+    final_event = turn_response[-1].event.payload
+    assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
+    assert isinstance(final_event.turn, Turn)
+    assert final_event.turn.session_id == session_id
+    assert final_event.turn.input_messages == input_messages
+    assert isinstance(final_event.turn.output_message, CompletionMessage)
+    assert len(final_event.turn.output_message.content) > 0
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 8647baddd..6861a29fd 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -130,4 +130,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.inference.fixtures",
     "llama_stack.providers.tests.safety.fixtures",
     "llama_stack.providers.tests.memory.fixtures",
+    "llama_stack.providers.tests.agents.fixtures",
 ]
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 9d961117e..4051cea69 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -18,9 +18,8 @@
 from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS
 
 # How to run this test:
 #
-# pytest llama_stack/providers/tests/inference/test_inference.py
+# pytest -v -s llama_stack/providers/tests/inference/test_inference.py
 # -m "(fireworks or ollama) and llama_3b"
-# -v -s --tb=short --disable-warnings
 # --env FIREWORKS_API_KEY=
 
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index 79ab2ffff..1f1050df4 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -20,7 +20,7 @@ from ..env import get_env_or_fail
 
 
 @pytest.fixture(scope="session")
-def meta_reference() -> ProviderFixture:
+def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="meta-reference",
@@ -31,7 +31,7 @@
 
 
 @pytest.fixture(scope="session")
-def pgvector() -> ProviderFixture:
+def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="pgvector",
@@ -48,7 +48,7 @@
 
 
 @pytest.fixture(scope="session")
-def weaviate() -> ProviderFixture:
+def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
         provider=Provider(
             provider_id="weaviate",
@@ -76,7 +76,7 @@ PROVIDER_PARAMS = [
 )
 async def memory_stack(request):
     fixture_name = request.param
-    fixture = request.getfixturevalue(fixture_name)
+    fixture = request.getfixturevalue(f"memory_{fixture_name}")
 
     impls = await resolve_impls_for_test_v2(
         [Api.memory],
diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py
index 760edcd36..e355d5908 100644
--- a/llama_stack/providers/tests/safety/test_safety.py
+++ b/llama_stack/providers/tests/safety/test_safety.py
@@ -11,6 +11,11 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
+# How to run this test:
+#
+# pytest -v -s llama_stack/providers/tests/safety/test_safety.py
+# -m "ollama"
+
 @pytest.mark.parametrize(
     "inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True