mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 23:28:53 +00:00
Refactor safety shield fixtures
This commit is contained in:
parent
743da9690b
commit
8645f8bc9e
6 changed files with 53 additions and 62 deletions
|
@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
|
|||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
|
||||
|
@ -75,28 +75,30 @@ def pytest_addoption(parser):
|
|||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-model",
|
||||
"--safety-shield",
|
||||
action="store",
|
||||
default="Llama-Guard-3-8B",
|
||||
help="Specify the safety model to use for testing",
|
||||
help="Specify the safety shield to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
safety_model = metafunc.config.getoption("--safety-model")
|
||||
if "safety_model" in metafunc.fixturenames:
|
||||
shield_id = metafunc.config.getoption("--safety-shield")
|
||||
if "safety_shield" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"safety_model",
|
||||
[pytest.param(safety_model, id="")],
|
||||
"safety_shield",
|
||||
[pytest.param(shield_id, id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
inference_model = metafunc.config.getoption("--inference-model")
|
||||
models = list(set({inference_model, safety_model}))
|
||||
models = set({inference_model})
|
||||
if safety_model := safety_model_from_shield(shield_id):
|
||||
models.add(safety_model)
|
||||
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
[pytest.param(models, id="")],
|
||||
[pytest.param(list(models), id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "agents_stack" in metafunc.fixturenames:
|
||||
|
|
|
@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
|||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..safety.fixtures import get_shield_to_register
|
||||
|
||||
|
||||
def pick_inference_model(inference_model):
|
||||
|
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(request, inference_model, safety_model):
|
||||
async def agents_stack(request, inference_model, safety_shield):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
|
@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model):
|
|||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
shield_input = get_shield_to_register(
|
||||
providers["safety"][0].provider_type, safety_model
|
||||
)
|
||||
inference_models = (
|
||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||
)
|
||||
|
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
|
|||
)
|
||||
for model in inference_models
|
||||
],
|
||||
shields=[shield_input],
|
||||
shields=[safety_shield],
|
||||
)
|
||||
return impls[Api.agents], impls[Api.memory]
|
||||
|
|
|
@ -82,7 +82,7 @@ async def create_agent_session(agents_impl, agent_config):
|
|||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(
|
||||
self, safety_model, agents_stack, common_params
|
||||
self, safety_shield, agents_stack, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agent_id, session_id = await create_agent_session(
|
||||
|
@ -90,8 +90,8 @@ class TestAgents:
|
|||
AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"input_shields": [safety_model],
|
||||
"output_shields": [safety_model],
|
||||
"input_shields": [safety_shield.shield_id],
|
||||
"output_shields": [safety_shield.shield_id],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue