llama-stack/llama_stack/providers/tests/agents/test_persistence.py
Sébastien Han e4a1579e63
build: format codebase imports using ruff linter (#1028)
# What does this PR do?

- Configured ruff linter to automatically fix import sorting issues.
- Set --exit-non-zero-on-fix to ensure non-zero exit code when fixes are
applied.
- Enabled the 'I' selection to focus on import-related linting rules.
- Ran the linter, and formatted all codebase imports accordingly.
- Removed the black dep from the "dev" group since we use ruff

Signed-off-by: Sébastien Han <seb@redhat.com>

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-02-13 10:06:21 -08:00

111 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents import AgentConfig, Turn
from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from .fixtures import pick_inference_model
from .utils import create_agent_session
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def common_params(inference_model):
inference_model = pick_inference_model(inference_model)
return dict(
model=inference_model,
instructions="You are a helpful assistant.",
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
class TestAgentPersistence:
@pytest.mark.asyncio
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": [],
"output_shields": [],
}
),
)
run_config = agents_stack.run_config
provider_config = run_config.providers["agents"][0].config
persistence_store = await kvstore_impl(SqliteKVStoreConfig(**provider_config["persistence_store"]))
await agents_impl.delete_agents_session(agent_id, session_id)
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
await agents_impl.delete_agents(agent_id)
agent_response = await persistence_store.get(f"agent:{agent_id}")
assert session_response is None
assert agent_response is None
@pytest.mark.asyncio
async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params):
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": [],
"output_shields": [],
}
),
)
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=sample_messages,
stream=True,
)
turn_response = [chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
turn_id = final_event.turn.turn_id
provider_config = agents_stack.run_config.providers["agents"][0].config
persistence_store = await kvstore_impl(SqliteKVStoreConfig(**provider_config["persistence_store"]))
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
assert isinstance(response, Turn)
assert response == final_event.turn
assert turn == final_event.turn.model_dump_json()
steps = final_event.turn.steps
step_id = steps[0].step_id
step_response = await agents_impl.get_agents_step(agent_id, session_id, turn_id, step_id)
assert step_response.step == steps[0]