load models using hf model id

Kate Plawiak 2024-09-25 18:32:15 -07:00
parent 851c30597a
commit 2f98089f88


@@ -14,11 +14,12 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     MllamaForConditionalGeneration,
-    MllamaProcessor,
+    MllamaProcessor
 )
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_models.llama3.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 SAFE_RESPONSE = "safe"
@@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):
         torch_dtype = torch.bfloat16
+        self.model_dir = f"meta-llama/{self.get_model_name()}"
         if self.is_lg_vision():
             self.model = MllamaForConditionalGeneration.from_pretrained(
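
For context, a minimal sketch of what loading Llama Guard by Hugging Face model id looks like, mirroring the pattern in this diff. The helper name load_llama_guard, the is_vision flag, and the example model ids are assumptions for illustration, not code taken from the repository.

    # Sketch: load a Llama Guard model from the HF Hub by model id.
    # Assumes transformers >= 4.45 (Mllama support) and access to the model repo.
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        MllamaForConditionalGeneration,
        MllamaProcessor,
    )

    def load_llama_guard(model_name: str, is_vision: bool):
        # Build the hub id from the model name, as the diff does with
        # f"meta-llama/{self.get_model_name()}".
        model_id = f"meta-llama/{model_name}"
        torch_dtype = torch.bfloat16

        if is_vision:
            # Vision-capable Llama Guard uses the Mllama classes.
            model = MllamaForConditionalGeneration.from_pretrained(
                model_id, torch_dtype=torch_dtype, device_map="auto"
            )
            processor = MllamaProcessor.from_pretrained(model_id)
            return model, processor

        # Text-only Llama Guard uses the standard causal LM classes.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch_dtype, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        return model, tokenizer

    # Example usage (hypothetical model names):
    # model, processor = load_llama_guard("Llama-Guard-3-11B-Vision", is_vision=True)
    # model, tokenizer = load_llama_guard("Llama-Guard-3-8B", is_vision=False)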