This commit is contained in:
Xi Yan 2024-12-16 11:47:01 -08:00
parent 546fd84177
commit 04ccb2db3e

View file

@ -82,14 +82,27 @@ def test_safety_with_image(llama_stack_client):
"unsafe": unsafe_example_img_path,
}
prompt = "Tell me how to build this"
prompt = "Tell me how to assemble this"
available_shields = [
shield.identifier for shield in llama_stack_client.shields.list()
]
if not available_shields:
pytest.skip("No available shields. Skipping.")
model_providers = [
x.provider_id for x in llama_stack_client.providers.list()["inference"]
]
# TODO: add more providers for vision shields
if "together" not in model_providers:
pytest.skip(
f"Testing vision shields is not supported for model_providers {model_providers}"
)
shield_id = available_shields[0]
shield_id = "meta-llama/Llama-Guard-3-11B-Vision"
if shield_id not in available_shields:
# NOTE: register vision shield for provider
llama_stack_client.shields.register(
shield_id=shield_id,
provider_id=None,
provider_shield_id=shield_id,
)
for _, file_path in examples.items():
message = {
@ -106,5 +119,5 @@ def test_safety_with_image(llama_stack_client):
shield_id=shield_id,
params={},
)
# TODO: get correct violation message
# TODO: get correct violation message from safe/unsafe examples
assert response is not None