From 3ae1597b9bf9e6f643cbec9abbc6fb02fb015c04 Mon Sep 17 00:00:00 2001
From: Kate Plawiak <113949869+kplawiak@users.noreply.github.com>
Date: Wed, 25 Sep 2024 18:40:09 -0700
Subject: [PATCH] load models using hf model id (#108)

---
 .../impls/meta_reference/safety/shields/llama_guard.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index b911be76a..5ee562179 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -14,11 +14,12 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     MllamaForConditionalGeneration,
-    MllamaProcessor,
+    MllamaProcessor
 )
+
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

-from llama_models.llama3.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import *  # noqa: F403

 SAFE_RESPONSE = "safe"

@@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):

         torch_dtype = torch.bfloat16

+        self.model_dir = f"meta-llama/{self.get_model_name()}"
+
         if self.is_lg_vision():
             self.model = MllamaForConditionalGeneration.from_pretrained(
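
[Editor's note, not part of the patch] The substantive change here is the added
self.model_dir = f"meta-llama/{self.get_model_name()}" line: instead of pointing
from_pretrained() at a pre-downloaded local checkpoint directory, the shield now
passes a Hugging Face Hub model id, which transformers resolves, downloads, and
caches automatically. Below is a minimal standalone sketch of that pattern. It is
not the actual LlamaGuardShield implementation; load_guard_model and the example
model name are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_guard_model(model_name: str):
    # Hypothetical helper illustrating the patch's pattern. Build a Hub id
    # such as "meta-llama/Llama-Guard-3-8B" (example name); from_pretrained()
    # accepts either a Hub id or a local path, so this one change is enough
    # to switch from local directories to Hub-managed downloads.
    model_dir = f"meta-llama/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,  # same dtype the patch configures
        device_map="auto",  # needs the accelerate package; illustrative
    )
    return tokenizer, model

Because from_pretrained() treats a string containing a "/" as a Hub repo id,
no other call sites need to change; gated repos such as meta-llama still
require an authenticated Hugging Face token at download time.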