From 82f420c4f0200225ef871a1a609c05e4f6e3ca54 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 25 Sep 2024 11:30:27 -0700 Subject: [PATCH] fix safety using inference (#99) --- .../impls/meta_reference/safety/shields/llama_guard.py | 1 + llama_stack/providers/registry/safety.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py index e8c7b3560..b911be76a 100644 --- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py @@ -18,6 +18,7 @@ from transformers import ( ) from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse +from llama_models.llama3.api.datatypes import Message, Role SAFE_RESPONSE = "safe" diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 1f353912b..ac14eaaac 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -28,6 +28,9 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.impls.meta_reference.safety", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", + api_dependencies=[ + Api.inference, + ], ), remote_provider_spec( api=Api.safety,