From 862ff6c65322a146eda36bd891db1dc8c652820e Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande
Date: Mon, 21 Oct 2024 15:56:54 +0530
Subject: [PATCH] Added tests for persistence

---
 .../tests/agents/test_agent_persistence.py | 157 ++++++++++++++++++++++++
 .../providers/tests/agents/test_agents.py  |   4 ++--
 .../tests/inference/test_inference.py      |   4 ++--
 .../tests/inference/test_prompt_adapter.py |   4 ++--
 .../providers/tests/memory/test_memory.py  |   4 ++--
 llama_stack/providers/tests/resolver.py    |   8 +++++++-
 .../providers/tests/safety/test_safety.py  |   4 ++--
 7 files changed, 174 insertions(+), 11 deletions(-)
 create mode 100644 llama_stack/providers/tests/agents/test_agent_persistence.py

diff --git a/llama_stack/providers/tests/agents/test_agent_persistence.py b/llama_stack/providers/tests/agents/test_agent_persistence.py
new file mode 100644
index 000000000..c93c6548e
--- /dev/null
+++ b/llama_stack/providers/tests/agents/test_agent_persistence.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+from llama_stack.providers.datatypes import *  # noqa: F403
+
+from dotenv import load_dotenv
+
+from llama_stack.providers.utils.kvstore import kvstore_impl
+
+# How to run this test:
+#
+# 1. Ensure you have a conda environment with the right dependencies installed.
+#    This includes `pytest` and `pytest-asyncio`.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider
+#    you are testing.
+#
+# 3. Run:
+#
+# ```bash
+# PROVIDER_ID=<your_provider> \
+#   PROVIDER_CONFIG=provider_config.yaml \
+#   pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
+#   --tb=short --disable-warnings
+# ```
+
+load_dotenv()
+
+
+@pytest_asyncio.fixture(scope="session")
+async def agents_settings():
+    impls = await resolve_impls_for_test(
+        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
+    )
+
+    return {
+        "impl": impls["impls"][Api.agents],
+        "memory_impl": impls["impls"][Api.memory],
+        "persistence": impls["persistence"],
+        "common_params": {
+            "model": "Llama3.1-8B-Instruct",
+            "instructions": "You are a helpful assistant.",
+        },
+    }
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What's the weather like today?"),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_delete_agents_and_sessions(agents_settings, sample_messages):
+    agents_impl = agents_settings["impl"]
+
+    # First, create an agent
+    agent_config = AgentConfig(
+        model=agents_settings["common_params"]["model"],
+        instructions=agents_settings["common_params"]["instructions"],
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+    persistence_store = await kvstore_impl(agents_settings["persistence"])
+
+    await agents_impl.delete_agents_session(agent_id, session_id)
+    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
+
+    await agents_impl.delete_agents(agent_id)
+    agent_response = await persistence_store.get(f"agent:{agent_id}")
+
+    # Both keys should be gone from the persistence store after deletion
+    assert session_response is None
+    assert agent_response is None
+
+
+@pytest.mark.asyncio
+async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
+    agents_impl = agents_settings["impl"]
+
+    # First, create an agent
+    agent_config = AgentConfig(
+        model=agents_settings["common_params"]["model"],
+        instructions=agents_settings["common_params"]["instructions"],
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+
+    # Create and execute a turn
+    turn_request = dict(
+        agent_id=agent_id,
+        session_id=session_id,
+        messages=sample_messages,
+        stream=True,
+    )
+
+    turn_response = [
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+    ]
+
+    final_event = turn_response[-1].event.payload
+    turn_id = final_event.turn.turn_id
+    persistence_store = await kvstore_impl(agents_settings["persistence"])
+    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
+    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
+
+    assert isinstance(response, Turn)
+    assert response == final_event.turn
+    # The KV store holds the turn serialized as JSON; parse it before comparing
+    assert turn is not None
+    assert Turn(**json.loads(turn)) == final_event.turn
+
+    steps = final_event.turn.steps
+    step_id = steps[0].step_id
+    step_response = await agents_impl.get_agents_step(
+        agent_id, session_id, turn_id, step_id
+    )
+
+    assert isinstance(step_response.step, Step)
+    assert step_response.step == steps[0]
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 9c34c3a28..aa88f79df 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -41,8 +41,8 @@ async def agents_settings():
     )
 
     return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
+        "impl": impls["impls"][Api.agents],
+        "memory_impl": impls["impls"][Api.memory],
         "common_params": {
             "model": "Llama3.1-8B-Instruct",
             "instructions": "You are a helpful assistant.",
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 09d6a69db..95e44fcb1 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -71,8 +71,8 @@ async def inference_settings(request):
     )
 
     return {
-        "impl": impls[Api.inference],
-        "models_impl": impls[Api.models],
+        "impl": impls["impls"][Api.inference],
+        "models_impl": impls["impls"][Api.models],
         "common_params": {
             "model": model,
             "tool_choice": ToolChoice.auto,
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 3a1e25d65..80a428c52 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -7,8 +7,8 @@
 import unittest
 
 from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.inference.api import *  # noqa: F403
-from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
+from llama_stack.apis.inference.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prompt_adapter import chat_completion_request_to_messages
 
 MODEL = "Llama3.1-8B-Instruct"
 
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index d92feaba8..e86cb6c9b 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -36,8 +36,8 @@ async def memory_settings():
         Api.memory,
     )
     return {
-        "memory_impl": impls[Api.memory],
-        "memory_banks_impl": impls[Api.memory_banks],
+        "memory_impl": impls["impls"][Api.memory],
+        "memory_banks_impl": impls["impls"][Api.memory_banks],
     }
 
 
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index de672b6dc..816dfdcf5 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -46,7 +46,13 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
             {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
         )
 
-    return impls
+    # Only agent provider configs define a persistence_store; default to None
+    # so the other test suites can still resolve their impls.
+    agents_providers = config_dict["providers"].get("agents", [])
+    persistence = (
+        agents_providers[0]["config"].get("persistence_store") if agents_providers else None
+    )
+    return {"impls": impls, "persistence": persistence}
 
 
 def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py
index 1861a7e8c..b951a2bc6 100644
--- a/llama_stack/providers/tests/safety/test_safety.py
+++ b/llama_stack/providers/tests/safety/test_safety.py
@@ -36,8 +36,8 @@ async def safety_settings():
     impls = await resolve_impls_for_test(Api.safety, deps=[Api.inference])
 
     return {
-        "impl": impls[Api.safety],
-        "shields_impl": impls[Api.shields],
+        "impl": impls["impls"][Api.safety],
"shields_impl": impls['impls'][Api.shields], }