diff --git a/llama_stack/providers/tests/agents/test_agent_persistence.py b/llama_stack/providers/tests/agents/test_agent_persistence.py new file mode 100644 index 000000000..a15887b33 --- /dev/null +++ b/llama_stack/providers/tests/agents/test_agent_persistence.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +import pytest_asyncio + +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.providers.tests.resolver import resolve_impls_for_test +from llama_stack.providers.datatypes import * # noqa: F403 + +from dotenv import load_dotenv + +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig + +# How to run this test: +# +# 1. Ensure you have a conda environment with the right dependencies installed. +# This includes `pytest` and `pytest-asyncio`. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \ +# --tb=short --disable-warnings +# ``` + +load_dotenv() + + +@pytest_asyncio.fixture(scope="session") +async def agents_settings(): + impls = await resolve_impls_for_test( + Api.agents, deps=[Api.inference, Api.memory, Api.safety] + ) + + return { + "impl": impls[Api.agents], + "memory_impl": impls[Api.memory], + "common_params": { + "model": "Llama3.1-8B-Instruct", + "instructions": "You are a helpful assistant.", + }, + } + + +@pytest.fixture +def sample_messages(): + return [ + UserMessage(content="What's the weather like today?"), + ] + + +@pytest.mark.asyncio +async def test_delete_agents_and_sessions(agents_settings, sample_messages): + agents_impl = agents_settings["impl"] + # First, create an agent + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[], + max_infer_iters=5, + ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + persistence_store = await kvstore_impl(agents_settings["persistence"]) + + await agents_impl.delete_agents_session(agent_id, session_id) + session_response = await persistence_store.get(f"session:{agent_id}:{session_id}") + + await agents_impl.delete_agents(agent_id) + agent_response = await persistence_store.get(f"agent:{agent_id}") + + assert session_response is None + assert agent_response is None + + +async def test_get_agent_turns_and_steps(agents_settings, sample_messages): + agents_impl = agents_settings["impl"] + + # First, create an agent + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[], + max_infer_iters=5, + ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=sample_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) + ] + + final_event = turn_response[-1].event.payload + turn_id = final_event.turn.turn_id + persistence_store = await kvstore_impl(SqliteKVStoreConfig()) + turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") + response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id) + + assert isinstance(response, Turn) + assert response == final_event.turn + assert turn == final_event.turn + + steps = final_event.turn.steps + step_id = steps[0].step_id + step_response = await agents_impl.get_agents_step( + agent_id, session_id, turn_id, step_id + ) + + assert isinstance(step_response.step, Step) + assert step_response.step == steps[0] diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 3a1e25d65..2c222ffa1 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -7,8 +7,10 @@ import unittest from llama_models.llama3.api import * # noqa: F403 -from llama_stack.inference.api import * # noqa: F403 -from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages +from llama_stack.apis.inference.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_messages, +) MODEL = "Llama3.1-8B-Instruct"