diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index b911be76a..5ee562179 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -14,11 +14,12 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     MllamaForConditionalGeneration,
-    MllamaProcessor,
+    MllamaProcessor
 )
+
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_models.llama3.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 
 SAFE_RESPONSE = "safe"
@@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):
 
         torch_dtype = torch.bfloat16
 
+        self.model_dir = f"meta-llama/{self.get_model_name()}"
+
         if self.is_lg_vision():
             self.model = MllamaForConditionalGeneration.from_pretrained(