load models using hf model id (#108)

This commit is contained in:
Kate Plawiak 2024-09-25 18:40:09 -07:00 committed by GitHub
parent e73e9110b7
commit 3ae1597b9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -14,11 +14,12 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
MllamaForConditionalGeneration,
MllamaProcessor,
MllamaProcessor
)
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_models.llama3.api.datatypes import Message, Role
from llama_models.llama3.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):
torch_dtype = torch.bfloat16
self.model_dir = f"meta-llama/{self.get_model_name()}"
if self.is_lg_vision():
self.model = MllamaForConditionalGeneration.from_pretrained(