mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Continues the refactor of tests. Tests from `providers/tests` should be considered deprecated. For this PR, I deleted most of the tests in - inference - safety - agents since much more comprehensive tests exist in `tests/integration/{inference,safety,agents}` already. I moved `test_persistence.py` from agents, but disabled all the tests since that test needs to be properly migrated. ## Test Plan ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v agents --vision-inference-model='' /Users/ashwin/homebrew/Caskroom/miniconda/base/envs/toolchain/lib/python3.10/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ======================================================================================================= test session starts ======================================================================================================== platform darwin -- Python 3.10.16, pytest-8.3.3, pluggy-1.5.0 -- /Users/ashwin/homebrew/Caskroom/miniconda/base/envs/toolchain/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-15.3.1-arm64-arm-64bit', 'Packages': {'pytest': '8.3.3', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.24.0', 'html': '4.1.1', 'metadata': '3.1.1', 'anyio': '4.8.0', 'nbval': '0.11.0'}} rootdir: /Users/ashwin/local/llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, anyio-4.8.0, nbval-0.11.0 asyncio: mode=strict, default_loop_scope=None collected 15 items 
agents/test_agents.py::test_agent_simple[txt=8B] PASSED agents/test_agents.py::test_tool_config[txt=8B] PASSED agents/test_agents.py::test_builtin_tool_web_search[txt=8B] PASSED agents/test_agents.py::test_builtin_tool_code_execution[txt=8B] PASSED agents/test_agents.py::test_code_interpreter_for_attachments[txt=8B] PASSED agents/test_agents.py::test_custom_tool[txt=8B] PASSED agents/test_agents.py::test_custom_tool_infinite_loop[txt=8B] PASSED agents/test_agents.py::test_tool_choice[txt=8B] PASSED agents/test_agents.py::test_rag_agent[txt=8B-builtin::rag/knowledge_search] PASSED agents/test_agents.py::test_rag_agent[txt=8B-builtin::rag] PASSED agents/test_agents.py::test_rag_agent_with_attachments[txt=8B] PASSED agents/test_agents.py::test_rag_and_code_agent[txt=8B] PASSED agents/test_agents.py::test_create_turn_response[txt=8B] PASSED agents/test_persistence.py::test_delete_agents_and_sessions SKIPPED (This test needs to be migrated to api / client-sdk world) agents/test_persistence.py::test_get_agent_turns_and_steps SKIPPED (This test needs to be migrated to api / client-sdk world) ```
118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
|
|
|
|
from llama_stack.apis.agents import AgentConfig, Turn
|
|
from llama_stack.apis.inference import SamplingParams, UserMessage
|
|
from llama_stack.providers.datatypes import Api
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
|
|
@pytest.fixture
def sample_messages():
    """Pytest fixture: a conversation consisting of a single user message."""
    weather_question = UserMessage(content="What's the weather like today?")
    return [weather_question]
|
|
|
|
|
|
def pick_inference_model(inference_model):
    """Choose the model identifier the agent should use.

    Currently this is a pass-through: the value supplied by the
    ``inference_model`` fixture is returned unchanged.
    """
    chosen_model = inference_model
    return chosen_model
|
|
|
|
|
|
async def create_agent_session(agents_impl, agent_config, session_name: str = "Test Session"):
    """Create an agent from ``agent_config`` and open one session on it.

    Callers unpack the awaited result as ``agent_id, session_id``, so this
    helper performs both steps: it registers the agent, then opens a session
    for that agent.

    Args:
        agents_impl: The agents API implementation
            (e.g. ``agents_stack.impls[Api.agents]``).
        agent_config: The ``AgentConfig`` used to create the agent.
        session_name: Name given to the newly created session.

    Returns:
        A ``(agent_id, session_id)`` tuple.
    """
    # NOTE(review): assumes create_agent(config) returns an object exposing
    # ``.agent_id`` and create_agent_session(agent_id, session_name) returns an
    # object exposing ``.session_id`` -- confirm against llama_stack.apis.agents.
    agent = await agents_impl.create_agent(agent_config)
    session = await agents_impl.create_agent_session(agent.agent_id, session_name)
    return agent.agent_id, session.session_id
|
|
|
|
|
|
@pytest.fixture
def common_params(inference_model):
    """Pytest fixture: baseline keyword arguments for building an ``AgentConfig``.

    The model comes from ``pick_inference_model``. Session persistence is
    switched on (``enable_session_persistence=True``), shields and tools are
    empty lists, and ``max_infer_iters`` is set to 5.
    """
    model = pick_inference_model(inference_model)
    sampling = SamplingParams(temperature=0.7, top_p=0.95)
    return {
        "model": model,
        "instructions": "You are a helpful assistant.",
        "enable_session_persistence": True,
        "sampling_params": sampling,
        "input_shields": [],
        "output_shields": [],
        "tools": [],
        "max_infer_iters": 5,
    }
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
async def test_delete_agents_and_sessions(agents_stack, common_params):
    """Deleting a session and then its agent removes both persisted records.

    Creates an agent plus a session. It then deletes the session and checks
    that the ``session:{agent_id}:{session_id}`` key is gone from the agents
    provider's persistence store. Finally it deletes the agent and checks that
    the ``agent:{agent_id}`` key is gone too.
    """
    # Module-level test function: it must not take ``self``, otherwise pytest
    # tries to resolve ``self`` as a fixture and errors once un-skipped.
    agents_impl = agents_stack.impls[Api.agents]
    agent_id, session_id = await create_agent_session(
        agents_impl,
        # Shields are overridden to empty explicitly so this test never depends
        # on common_params for its safety configuration.
        AgentConfig(**{**common_params, "input_shields": [], "output_shields": []}),
    )

    # Read the provider's store directly so deletions are verified at the
    # storage layer rather than through the same API that performed them.
    # NOTE(review): assumes the first configured agents provider persists to SQLite.
    provider_config = agents_stack.run_config.providers["agents"][0].config
    persistence_store = await kvstore_impl(SqliteKVStoreConfig(**provider_config["persistence_store"]))

    await agents_impl.delete_agents_session(agent_id, session_id)
    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
    assert session_response is None

    await agents_impl.delete_agents(agent_id)
    agent_response = await persistence_store.get(f"agent:{agent_id}")
    assert agent_response is None
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
async def test_get_agent_turns_and_steps(agents_stack, sample_messages, common_params):
    """A completed turn, and its first step, can be read back after execution.

    Runs one streaming turn, then checks three things:

    - the turn returned by ``get_agents_turn`` equals the turn carried in the
      final stream event;
    - the raw record stored under ``session:{agent_id}:{session_id}:{turn_id}``
      is that turn's JSON;
    - ``get_agents_step`` returns the turn's first step.
    """
    # Module-level test function: it must not take ``self``, otherwise pytest
    # tries to resolve ``self`` as a fixture and errors once un-skipped.
    agents_impl = agents_stack.impls[Api.agents]
    agent_id, session_id = await create_agent_session(
        agents_impl,
        AgentConfig(**{**common_params, "input_shields": [], "output_shields": []}),
    )

    # Create and execute a turn, draining the whole response stream.
    turn_response = [
        chunk
        async for chunk in await agents_impl.create_agent_turn(
            agent_id=agent_id,
            session_id=session_id,
            messages=sample_messages,
            stream=True,
        )
    ]
    assert turn_response, "create_agent_turn produced no stream chunks"

    # The final chunk's payload is expected to hold the completed turn.
    final_event = turn_response[-1].event.payload
    turn_id = final_event.turn.turn_id

    provider_config = agents_stack.run_config.providers["agents"][0].config
    persistence_store = await kvstore_impl(SqliteKVStoreConfig(**provider_config["persistence_store"]))
    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)

    assert isinstance(response, Turn)
    assert response == final_event.turn
    assert turn == final_event.turn.model_dump_json()

    steps = final_event.turn.steps
    # Guard the index below so an empty turn fails with a clear message
    # instead of a bare IndexError.
    assert steps, "completed turn has no steps to look up"
    step_id = steps[0].step_id
    step_response = await agents_impl.get_agents_step(agent_id, session_id, turn_id, step_id)

    assert step_response.step == steps[0]