Added tests for persistence

This commit is contained in:
Sarthak Deshpande 2024-10-21 15:56:54 +05:30
parent cae5b0708b
commit 862ff6c653
7 changed files with 157 additions and 11 deletions

View file

@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
from llama_stack.providers.datatypes import * # noqa: F403
from dotenv import load_dotenv
from llama_stack.providers.impls.meta_reference.agents.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
# How to run this test:
#
# 1. Ensure you have a conda environment with the right dependencies installed.
# This includes `pytest` and `pytest-asyncio`.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
# --tb=short --disable-warnings
# ```
load_dotenv()
@pytest_asyncio.fixture(scope="session")
async def agents_settings():
impls = await resolve_impls_for_test(
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
)
return {
"impl": impls['impls'][Api.agents],
"memory_impl": impls['impls'][Api.memory],
"persistence": impls['persistence'],
"common_params": {
"model": "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",
},
}
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
]
@pytest.mark.asyncio
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
# First, create an agent
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
persistence_store = await kvstore_impl(agents_settings['persistence'])
await agents_impl.delete_agents_session(agent_id, session_id)
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
await agents_impl.delete_agents(agent_id)
agent_response = await persistence_store.get(f"agent:{agent_id}")
assert session_response is None
assert agent_response is None
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
# First, create an agent
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=sample_messages,
stream=True,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
final_event = turn_response[-1].event.payload
turn_id = final_event.turn.turn_id
persistence_store = await kvstore_impl(agents_settings['persistence'])
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
assert isinstance(response, Turn)
assert response == final_event.turn
assert turn == final_event.turn
steps = final_event.turn.steps
step_id = steps[0].step_id
step_response = await agents_impl.get_agents_step(agent_id, session_id, turn_id, step_id)
assert isinstance(step_response.step, Step)
assert step_response.step == steps[0]

View file

@ -41,8 +41,8 @@ async def agents_settings():
)
return {
"impl": impls[Api.agents],
"memory_impl": impls[Api.memory],
"impl": impls['impls'][Api.agents],
"memory_impl": impls['impls'][Api.memory],
"common_params": {
"model": "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",

View file

@ -71,8 +71,8 @@ async def inference_settings(request):
)
return {
"impl": impls[Api.inference],
"models_impl": impls[Api.models],
"impl": impls['impls'][Api.inference],
"models_impl": impls['impls'][Api.models],
"common_params": {
"model": model,
"tool_choice": ToolChoice.auto,

View file

@ -7,8 +7,8 @@
import unittest
from llama_models.llama3.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
from llama_stack.apis.inference.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prompt_adapter import chat_completion_request_to_messages
MODEL = "Llama3.1-8B-Instruct"

View file

@ -36,8 +36,8 @@ async def memory_settings():
Api.memory,
)
return {
"memory_impl": impls[Api.memory],
"memory_banks_impl": impls[Api.memory_banks],
"memory_impl": impls['impls'][Api.memory],
"memory_banks_impl": impls['impls'][Api.memory_banks],
}

View file

@ -46,7 +46,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
return {"impls": impls, "persistence": config_dict['providers']['agents'][0]['config']['persistence_store']}
def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:

View file

@ -36,8 +36,8 @@ async def safety_settings():
impls = await resolve_impls_for_test(Api.safety, deps=[Api.inference])
return {
"impl": impls[Api.safety],
"shields_impl": impls[Api.shields],
"impl": impls['impls'][Api.safety],
"shields_impl": impls['impls'][Api.shields],
}