diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py index e8c7b3560..b911be76a 100644 --- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py @@ -18,6 +18,7 @@ from transformers import ( ) from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse +from llama_models.llama3.api.datatypes import Message, Role SAFE_RESPONSE = "safe" diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 1f353912b..ac14eaaac 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -28,6 +28,9 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.impls.meta_reference.safety", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", + api_dependencies=[ + Api.inference, + ], ), remote_provider_spec( api=Api.safety,