mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
# Context For test automation, the end goal is to run a single pytest command from the root test directory (llama_stack/providers/tests/.) such that we execute push-blocking tests. The work plan: 1) trigger pytest from llama_stack/providers/tests/. 2) use a config file to determine what tests and parametrization we want to run # What does this PR do? 1) consolidates the "inference-models" / "embedding-model" / "judge-model" ... options in root conftest.py. Without this change, we will hit an error when trying to run `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/.` because of duplicated `addoptions` definitions across child conftest files. 2) Add a `config` option to specify test config in YAML. (see [`ci_test_config.yaml`](https://gist.github.com/sixianyi0721/5b37fbce4069139445c2f06f6e42f87e) for an example config file) For provider_fixtures, we allow users to use either a default fixture combination or define their own {api:provider} combinations. ``` memory: .... fixtures: provider_fixtures: - default_fixture_param_id: ollama // use default fixture combination with param_id="ollama" in [providers/tests/memory/conftest.py](https://fburl.com/mtjzwsmk) - inference: sentence_transformers memory: faiss - default_fixture_param_id: chroma ``` 3) generate tests according to the config. Logic lives in two places: a) in `{api}/conftest.py::pytest_generate_tests`, we read from config to do parametrization. b) after test collection, in `pytest_collection_modifyitems`, we filter the tests to include only functions listed in config. ## Test Plan 1) `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/. --collect-only --config=ci_test_config.yaml` Using the `--collect-only` flag to print the pytests listed in the config file (`ci_test_config.yaml`). 
output: [gist](https://gist.github.com/sixianyi0721/05145e60d4d085c17cfb304beeb1e60e) 2) sanity check on `--inference-model` option ``` pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py ``` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.agents import AgentConfig, Turn
|
|
from llama_stack.apis.inference import SamplingParams, UserMessage
|
|
from llama_stack.providers.datatypes import Api
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
from .fixtures import pick_inference_model
|
|
|
|
from .utils import create_agent_session
|
|
|
|
|
|
@pytest.fixture
def sample_messages():
    """Provide the one-element list of user messages used to drive a turn."""
    weather_question = UserMessage(content="What's the weather like today?")
    return [weather_question]
|
|
|
|
|
|
@pytest.fixture
def common_params(inference_model):
    """Build the baseline keyword arguments shared by the agent tests.

    The ``inference_model`` fixture value is passed through
    ``pick_inference_model`` and the result becomes the ``model`` entry.
    The returned mapping is unpacked into ``AgentConfig`` by the tests.
    """
    selected_model = pick_inference_model(inference_model)
    sampling = SamplingParams(temperature=0.7, top_p=0.95)

    return {
        "model": selected_model,
        "instructions": "You are a helpful assistant.",
        "enable_session_persistence": True,
        "sampling_params": sampling,
        "input_shields": [],
        "output_shields": [],
        "tools": [],
        "max_infer_iters": 5,
    }
|
|
|
|
|
|
class TestAgentPersistence:
    """Exercise agent persistence against the configured key-value store.

    Each test creates an agent session through the agents API implementation
    and then reads the backing store directly to confirm what was written or
    removed.
    """

    @staticmethod
    def _build_agent_config(common_params):
        """Return an ``AgentConfig`` from the shared params with empty shields."""
        return AgentConfig(
            **{
                **common_params,
                # Explicitly force both shield lists to be empty, whatever
                # the shared params carry.
                "input_shields": [],
                "output_shields": [],
            }
        )

    @staticmethod
    async def _open_persistence_store(agents_stack):
        """Open the kvstore described by the first agents provider's config.

        Reads ``persistence_store`` from the provider config and feeds it to
        ``SqliteKVStoreConfig``, so the tests inspect the same store settings
        the provider was configured with.
        """
        provider_config = agents_stack.run_config.providers["agents"][0].config
        return await kvstore_impl(
            SqliteKVStoreConfig(**provider_config["persistence_store"])
        )

    @pytest.mark.asyncio
    async def test_delete_agents_and_sessions(self, agents_stack, common_params):
        """Deleting a session, then its agent, leaves neither key in the store."""
        agents_impl = agents_stack.impls[Api.agents]
        agent_id, session_id = await create_agent_session(
            agents_impl,
            self._build_agent_config(common_params),
        )

        persistence_store = await self._open_persistence_store(agents_stack)

        await agents_impl.delete_agents_session(agent_id, session_id)
        session_response = await persistence_store.get(
            f"session:{agent_id}:{session_id}"
        )

        await agents_impl.delete_agents(agent_id)
        agent_response = await persistence_store.get(f"agent:{agent_id}")

        assert session_response is None
        assert agent_response is None

    @pytest.mark.asyncio
    async def test_get_agent_turns_and_steps(
        self, agents_stack, sample_messages, common_params
    ):
        """A completed turn is retrievable via the API and stored as JSON."""
        agents_impl = agents_stack.impls[Api.agents]

        agent_id, session_id = await create_agent_session(
            agents_impl,
            self._build_agent_config(common_params),
        )

        # Create and execute a turn, draining the whole stream.
        turn_request = dict(
            agent_id=agent_id,
            session_id=session_id,
            messages=sample_messages,
            stream=True,
        )
        turn_response = [
            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
        ]

        # Fail with a readable message instead of an IndexError on [-1].
        assert turn_response, "create_agent_turn yielded no stream chunks"

        # The last streamed chunk's payload carries the finished turn.
        final_event = turn_response[-1].event.payload
        turn_id = final_event.turn.turn_id

        persistence_store = await self._open_persistence_store(agents_stack)
        turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
        response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)

        assert isinstance(response, Turn)
        assert response == final_event.turn
        # The raw stored value is compared against the turn's JSON dump.
        assert turn == final_event.turn.model_dump_json()

        steps = final_event.turn.steps
        # Fail with a readable message instead of an IndexError on [0].
        assert steps, "completed turn contains no steps"

        step_id = steps[0].step_id
        step_response = await agents_impl.get_agents_step(
            agent_id, session_id, turn_id, step_id
        )

        assert step_response.step == steps[0]
|