Refactor safety shield fixtures

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:58:58 -08:00
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions

View file

@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from .fixtures import AGENTS_FIXTURES
@ -75,28 +75,30 @@ def pytest_addoption(parser):
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-model",
"--safety-shield",
action="store",
default="Llama-Guard-3-8B",
help="Specify the safety model to use for testing",
help="Specify the safety shield to use for testing",
)
def pytest_generate_tests(metafunc):
safety_model = metafunc.config.getoption("--safety-model")
if "safety_model" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if "safety_shield" in metafunc.fixturenames:
metafunc.parametrize(
"safety_model",
[pytest.param(safety_model, id="")],
"safety_shield",
[pytest.param(shield_id, id="")],
indirect=True,
)
if "inference_model" in metafunc.fixturenames:
inference_model = metafunc.config.getoption("--inference-model")
models = list(set({inference_model, safety_model}))
models = set({inference_model})
if safety_model := safety_model_from_shield(shield_id):
models.add(safety_model)
metafunc.parametrize(
"inference_model",
[pytest.param(models, id="")],
[pytest.param(list(models), id="")],
indirect=True,
)
if "agents_stack" in metafunc.fixturenames:

View file

@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import (
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_model):
async def agents_stack(request, inference_model, safety_shield):
fixture_dict = request.param
providers = {}
@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
shield_input = get_shield_to_register(
providers["safety"][0].provider_type, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
)
for model in inference_models
],
shields=[shield_input],
shields=[safety_shield],
)
return impls[Api.agents], impls[Api.memory]

View file

@ -82,7 +82,7 @@ async def create_agent_session(agents_impl, agent_config):
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(
self, safety_model, agents_stack, common_params
self, safety_shield, agents_stack, common_params
):
agents_impl, _ = agents_stack
agent_id, session_id = await create_agent_session(
@ -90,8 +90,8 @@ class TestAgents:
AgentConfig(
**{
**common_params,
"input_shields": [safety_model],
"output_shields": [safety_model],
"input_shields": [safety_shield.shield_id],
"output_shields": [safety_shield.shield_id],
}
),
)