load models using hf model id (#108)

This commit is contained in:
Kate Plawiak 2024-09-25 18:40:09 -07:00 committed by GitHub
parent e73e9110b7
commit 3ae1597b9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -14,11 +14,12 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     MllamaForConditionalGeneration,
-    MllamaProcessor,
+    MllamaProcessor
 )
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_models.llama3.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import *  # noqa: F403

 SAFE_RESPONSE = "safe"
@@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):
         torch_dtype = torch.bfloat16

+        self.model_dir = f"meta-llama/{self.get_model_name()}"
+
         if self.is_lg_vision():
             self.model = MllamaForConditionalGeneration.from_pretrained(