Refactor safety shield fixtures

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:58:58 -08:00
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions

View file

@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
return remote_stack_fixture()
def safety_model_from_shield(shield_id):
if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
return None
return shield_id
@pytest.fixture(scope="session")
def safety_model(request):
def safety_shield(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--safety-model", None)
shield_id = request.param
else:
shield_id = request.config.getoption("--safety-shield", None)
if shield_id == "bedrock":
shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
else:
params = {}
return ShieldInput(
shield_id=shield_id,
params=params,
)
@pytest.fixture(scope="session")
def safety_llama_guard(safety_model) -> ProviderFixture:
def safety_llama_guard() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="inline::llama-guard",
provider_id="llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig().model_dump(),
)
@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="inline::prompt-guard",
provider_id="prompt-guard",
provider_type="inline::prompt-guard",
config=PromptGuardConfig().model_dump(),
)
@ -80,7 +99,7 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest_asyncio.fixture(scope="session")
async def safety_stack(inference_model, safety_model, request):
async def safety_stack(inference_model, safety_shield, request):
# We need an inference + safety fixture to test safety
fixture_dict = request.param
inference_fixture = request.getfixturevalue(
@ -98,32 +117,13 @@ async def safety_stack(inference_model, safety_model, request):
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield_input = get_shield_to_register(shield_provider_type, safety_model)
print(f"inference_model: {inference_model}")
print(f"shield_input = {shield_input}")
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[ModelInput(model_id=inference_model)],
shields=[shield_input],
shields=[safety_shield],
)
shield = await impls[Api.shields].get_shield(shield_input.shield_id)
shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
return impls[Api.safety], impls[Api.shields], shield
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
if provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
else:
params = {}
identifier = safety_model
return ShieldInput(
shield_id=identifier,
params=params,
)