mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 17:14:30 +00:00
Continues the refactor of tests.
Tests from `providers/tests` should be considered deprecated. For this
PR, I deleted most of the tests in
- inference
- safety
- agents
since much more comprehensive tests exist in
`tests/integration/{inference,safety,agents}` already.
I moved `test_persistence.py` from agents, but disabled all the tests
since that test needs to be properly migrated.
## Test Plan
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents --vision-inference-model=''
/Users/ashwin/homebrew/Caskroom/miniconda/base/envs/toolchain/lib/python3.10/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
======================================================================================================= test session starts ========================================================================================================
platform darwin -- Python 3.10.16, pytest-8.3.3, pluggy-1.5.0 -- /Users/ashwin/homebrew/Caskroom/miniconda/base/envs/toolchain/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.16', 'Platform': 'macOS-15.3.1-arm64-arm-64bit', 'Packages': {'pytest': '8.3.3', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.24.0', 'html': '4.1.1', 'metadata': '3.1.1', 'anyio': '4.8.0', 'nbval': '0.11.0'}}
rootdir: /Users/ashwin/local/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, anyio-4.8.0, nbval-0.11.0
asyncio: mode=strict, default_loop_scope=None
collected 15 items
agents/test_agents.py::test_agent_simple[txt=8B] PASSED
agents/test_agents.py::test_tool_config[txt=8B] PASSED
agents/test_agents.py::test_builtin_tool_web_search[txt=8B] PASSED
agents/test_agents.py::test_builtin_tool_code_execution[txt=8B] PASSED
agents/test_agents.py::test_code_interpreter_for_attachments[txt=8B] PASSED
agents/test_agents.py::test_custom_tool[txt=8B] PASSED
agents/test_agents.py::test_custom_tool_infinite_loop[txt=8B] PASSED
agents/test_agents.py::test_tool_choice[txt=8B] PASSED
agents/test_agents.py::test_rag_agent[txt=8B-builtin::rag/knowledge_search] PASSED
agents/test_agents.py::test_rag_agent[txt=8B-builtin::rag] PASSED
agents/test_agents.py::test_rag_agent_with_attachments[txt=8B] PASSED
agents/test_agents.py::test_rag_and_code_agent[txt=8B] PASSED
agents/test_agents.py::test_create_turn_response[txt=8B] PASSED
agents/test_persistence.py::test_delete_agents_and_sessions SKIPPED (This test needs to be migrated to api / client-sdk world)
agents/test_persistence.py::test_get_agent_turns_and_steps SKIPPED (This test needs to be migrated to api / client-sdk world)
```
137 lines
4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import ModelInput
|
|
from llama_stack.apis.shields import ShieldInput
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
|
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
|
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
|
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_remote() -> ProviderFixture:
    """Session-scoped fixture that points safety tests at a remote stack."""
    return remote_stack_fixture()
|
|
|
|
|
|
def safety_model_from_shield(shield_id):
    """Map a shield identifier to the safety model it requires.

    Shields that need no dedicated safety model (remote/scanner-style
    shields) map to ``None``; any other shield id doubles as its own
    model id.
    """
    modelless_shields = {"Bedrock", "CodeScanner", "CodeShield"}
    return None if shield_id in modelless_shields else shield_id
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_shield(request):
    """Resolve the shield under test into a ``ShieldInput`` (or ``None``).

    The shield id comes from direct parametrization when present, otherwise
    from the ``--safety-shield`` CLI option. The special ``"bedrock"`` alias
    is expanded from environment variables, including the guardrail-version
    parameter Bedrock requires.
    """
    shield_id = (
        request.param
        if hasattr(request, "param")
        else request.config.getoption("--safety-shield", None)
    )

    params = {}
    if shield_id == "bedrock":
        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}

    if not shield_id:
        return None

    return ShieldInput(
        shield_id=shield_id,
        params=params,
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_llama_guard() -> ProviderFixture:
    """Inline Llama Guard safety provider with its default config."""
    provider = Provider(
        provider_id="llama-guard",
        provider_type="inline::llama-guard",
        config=LlamaGuardConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
|
|
|
|
# TODO: this is not tested yet; we would need to configure the run_shield() test
# and parametrize it with the "prompt" for testing depending on the safety fixture
# we are using.
@pytest.fixture(scope="session")
def safety_prompt_guard() -> ProviderFixture:
    """Inline Prompt Guard safety provider with its default config."""
    provider = Provider(
        provider_id="prompt-guard",
        provider_type="inline::prompt-guard",
        config=PromptGuardConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_bedrock() -> ProviderFixture:
    """Remote AWS Bedrock safety provider with its default config."""
    provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=BedrockSafetyConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
    """Remote NVIDIA safety provider with its default config."""
    provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIASafetyConfig().model_dump(),
    )
    return ProviderFixture(providers=[provider])
|
|
|
|
|
|
# Names that safety tests may be parametrized over; each entry corresponds
# to a `safety_<name>` fixture defined in this module.
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def safety_stack(inference_model, safety_shield, request):
    """Assemble a test stack wiring together inference + safety providers.

    Returns a ``(safety_impl, shields_impl, shield)`` triple built from the
    fixture combination selected via ``request.param``.
    """
    # Safety cannot be exercised alone: an inference provider must be
    # wired up alongside it.
    fixture_names = request.param

    providers = {}
    provider_data = {}
    for api_key in ("inference", "safety"):
        fixture = request.getfixturevalue(f"{api_key}_{fixture_names[api_key]}")
        providers[api_key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    test_stack = await construct_stack_for_test(
        [Api.safety, Api.shields, Api.inference],
        providers,
        provider_data,
        models=[ModelInput(model_id=inference_model)],
        shields=[safety_shield],
    )

    shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
    return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
|