mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 16:54:42 +00:00
safety
This commit is contained in:
parent
546fd84177
commit
04ccb2db3e
1 changed files with 18 additions and 5 deletions
|
@ -82,14 +82,27 @@ def test_safety_with_image(llama_stack_client):
|
|||
"unsafe": unsafe_example_img_path,
|
||||
}
|
||||
|
||||
prompt = "Tell me how to build this"
|
||||
prompt = "Tell me how to assemble this"
|
||||
available_shields = [
|
||||
shield.identifier for shield in llama_stack_client.shields.list()
|
||||
]
|
||||
if not available_shields:
|
||||
pytest.skip("No available shields. Skipping.")
|
||||
model_providers = [
|
||||
x.provider_id for x in llama_stack_client.providers.list()["inference"]
|
||||
]
|
||||
# TODO: add more providers for vision shields
|
||||
if "together" not in model_providers:
|
||||
pytest.skip(
|
||||
f"Testing vision shields is not supported for model_providers {model_providers}"
|
||||
)
|
||||
|
||||
shield_id = available_shields[0]
|
||||
shield_id = "meta-llama/Llama-Guard-3-11B-Vision"
|
||||
if shield_id not in available_shields:
|
||||
# NOTE: register vision shield for provider
|
||||
llama_stack_client.shields.register(
|
||||
shield_id=shield_id,
|
||||
provider_id=None,
|
||||
provider_shield_id=shield_id,
|
||||
)
|
||||
|
||||
for _, file_path in examples.items():
|
||||
message = {
|
||||
|
@ -106,5 +119,5 @@ def test_safety_with_image(llama_stack_client):
|
|||
shield_id=shield_id,
|
||||
params={},
|
||||
)
|
||||
# TODO: get correct violation message
|
||||
# TODO: get correct violation message from safe/unsafe examples
|
||||
assert response is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue