From 8645f8bc9e79f76c6951ef2e19b5d7080de21ad3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 15:58:58 -0800 Subject: [PATCH] Refactor safety shield fixtures --- .../providers/tests/agents/conftest.py | 20 ++++--- .../providers/tests/agents/fixtures.py | 8 +-- .../providers/tests/agents/test_agents.py | 6 +- .../providers/tests/safety/conftest.py | 18 +++--- .../providers/tests/safety/fixtures.py | 56 +++++++++---------- .../providers/tests/safety/test_safety.py | 7 --- 6 files changed, 53 insertions(+), 62 deletions(-) diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index aa3910b39..c4f766e26 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES -from ..safety.fixtures import SAFETY_FIXTURES +from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield from .fixtures import AGENTS_FIXTURES @@ -75,28 +75,30 @@ def pytest_addoption(parser): help="Specify the inference model to use for testing", ) parser.addoption( - "--safety-model", + "--safety-shield", action="store", default="Llama-Guard-3-8B", - help="Specify the safety model to use for testing", + help="Specify the safety shield to use for testing", ) def pytest_generate_tests(metafunc): - safety_model = metafunc.config.getoption("--safety-model") - if "safety_model" in metafunc.fixturenames: + shield_id = metafunc.config.getoption("--safety-shield") + if "safety_shield" in metafunc.fixturenames: metafunc.parametrize( - "safety_model", - [pytest.param(safety_model, id="")], + "safety_shield", + [pytest.param(shield_id, id="")], indirect=True, ) if "inference_model" in metafunc.fixturenames: inference_model = metafunc.config.getoption("--inference-model") - models = list(set({inference_model, safety_model})) + models = set({inference_model}) + if safety_model := safety_model_from_shield(shield_id): + models.add(safety_model) metafunc.parametrize( "inference_model", - [pytest.param(models, id="")], + [pytest.param(list(models), id="")], indirect=True, ) if "agents_stack" in metafunc.fixturenames: diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index db157174f..c58741f62 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture -from ..safety.fixtures import get_shield_to_register def pick_inference_model(inference_model): @@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") -async def agents_stack(request, inference_model, safety_model): +async def agents_stack(request, inference_model, safety_shield): fixture_dict = request.param providers = {} @@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model): if fixture.provider_data: provider_data.update(fixture.provider_data) - shield_input = get_shield_to_register( - providers["safety"][0].provider_type, safety_model - ) inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) @@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model): ) for model in inference_models ], - shields=[shield_input], + shields=[safety_shield], ) return impls[Api.agents], impls[Api.memory] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 47e5a751f..bdfa38bc4 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -82,7 +82,7 @@ async def create_agent_session(agents_impl, agent_config): class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety( - self, safety_model, agents_stack, common_params + self, safety_shield, agents_stack, common_params ): agents_impl, _ = agents_stack agent_id, session_id = await create_agent_session( @@ -90,8 +90,8 @@ class TestAgents: AgentConfig( **{ **common_params, - "input_shields": [safety_model], - "output_shields": [safety_model], + "input_shields": [safety_shield.shield_id], + "output_shields": [safety_shield.shield_id], } ), ) diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index cb380ce57..76eb418ea 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -66,14 +66,14 @@ def pytest_configure(config): def pytest_addoption(parser): parser.addoption( - "--safety-model", + "--safety-shield", action="store", default=None, - help="Specify the safety model to use for testing", + help="Specify the safety shield to use for testing", ) -SAFETY_MODEL_PARAMS = [ +SAFETY_SHIELD_PARAMS = [ pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), ] @@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc): # But a user can also pass in a custom combination via the CLI by doing # `--providers inference=together,safety=meta_reference` - if "safety_model" in metafunc.fixturenames: - model = metafunc.config.getoption("--safety-model") - if model: - params = [pytest.param(model, id="")] + if "safety_shield" in metafunc.fixturenames: + shield_id = metafunc.config.getoption("--safety-shield") + if shield_id: + params = [pytest.param(shield_id, id="")] else: - params = SAFETY_MODEL_PARAMS - for fixture in ["inference_model", "safety_model"]: + params = SAFETY_SHIELD_PARAMS + for fixture in ["inference_model", "safety_shield"]: metafunc.parametrize( fixture, params, diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index b73c2d798..ade201b11 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture: return remote_stack_fixture() +def safety_model_from_shield(shield_id): + if shield_id in ("Bedrock", "CodeScanner", "CodeShield"): + return None + + return shield_id + + @pytest.fixture(scope="session") -def safety_model(request): +def safety_shield(request): if hasattr(request, "param"): - return request.param - return request.config.getoption("--safety-model", None) + shield_id = request.param + else: + shield_id = request.config.getoption("--safety-shield", None) + + if shield_id == "bedrock": + shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") + params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} + else: + params = {} + + return ShieldInput( + shield_id=shield_id, + params=params, + ) @pytest.fixture(scope="session") -def safety_llama_guard(safety_model) -> ProviderFixture: +def safety_llama_guard() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="inline::llama-guard", + provider_id="llama-guard", provider_type="inline::llama-guard", config=LlamaGuardConfig().model_dump(), ) @@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="inline::prompt-guard", + provider_id="prompt-guard", provider_type="inline::prompt-guard", config=PromptGuardConfig().model_dump(), ) @@ -80,7 +99,7 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] @pytest_asyncio.fixture(scope="session") -async def safety_stack(inference_model, safety_model, request): +async def safety_stack(inference_model, safety_shield, request): # We need an inference + safety fixture to test safety fixture_dict = request.param inference_fixture = request.getfixturevalue( @@ -98,32 +117,13 @@ async def safety_stack(inference_model, safety_model, request): if safety_fixture.provider_data: provider_data.update(safety_fixture.provider_data) - shield_provider_type = safety_fixture.providers[0].provider_type - shield_input = get_shield_to_register(shield_provider_type, safety_model) - - print(f"inference_model: {inference_model}") - print(f"shield_input = {shield_input}") impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, provider_data, models=[ModelInput(model_id=inference_model)], - shields=[shield_input], + shields=[safety_shield], ) - shield = await impls[Api.shields].get_shield(shield_input.shield_id) + shield = await impls[Api.shields].get_shield(safety_shield.shield_id) return impls[Api.safety], impls[Api.shields], shield - - -def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: - if provider_type == "remote::bedrock": - identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") - params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} - else: - params = {} - identifier = safety_model - - return ShieldInput( - shield_id=identifier, - params=params, - ) diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 9daa7bf40..2b3e2d2f5 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 class TestSafety: - @pytest.mark.asyncio - async def test_new_shield(self, safety_stack): - _, shields_impl, shield = safety_stack - assert shield is not None - assert shield.provider_resource_id == shield.identifier - assert shield.provider_id is not None - @pytest.mark.asyncio async def test_shield_list(self, safety_stack): _, shields_impl, _ = safety_stack