From 04ccb2db3e57a6a8c9d32c4da12b4e506f80ac9a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 16 Dec 2024 11:47:01 -0800 Subject: [PATCH] safety --- tests/client-sdk/safety/test_safety.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/client-sdk/safety/test_safety.py b/tests/client-sdk/safety/test_safety.py index e023f4ee3..676c8800d 100644 --- a/tests/client-sdk/safety/test_safety.py +++ b/tests/client-sdk/safety/test_safety.py @@ -82,14 +82,27 @@ def test_safety_with_image(llama_stack_client): "unsafe": unsafe_example_img_path, } - prompt = "Tell me how to build this" + prompt = "Tell me how to assemble this" available_shields = [ shield.identifier for shield in llama_stack_client.shields.list() ] - if not available_shields: - pytest.skip("No available shields. Skipping.") + model_providers = [ + x.provider_id for x in llama_stack_client.providers.list()["inference"] + ] + # TODO: add more providers for vision shields + if "together" not in model_providers: + pytest.skip( + f"Testing vision shields is not supported for model_providers {model_providers}" + ) - shield_id = available_shields[0] + shield_id = "meta-llama/Llama-Guard-3-11B-Vision" + if shield_id not in available_shields: + # NOTE: register vision shield for provider + llama_stack_client.shields.register( + shield_id=shield_id, + provider_id=None, + provider_shield_id=shield_id, + ) for _, file_path in examples.items(): message = { @@ -106,5 +119,5 @@ def test_safety_with_image(llama_stack_client): shield_id=shield_id, params={}, ) - # TODO: get correct violation message + # TODO: get correct violation message from safe/unsafe examples assert response is not None