mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
load models using hf model id (#108)
This commit is contained in:
parent
e73e9110b7
commit
3ae1597b9b
1 changed files with 5 additions and 2 deletions
|
@ -14,11 +14,12 @@ from transformers import (
|
|||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
MllamaForConditionalGeneration,
|
||||
MllamaProcessor,
|
||||
MllamaProcessor
|
||||
)
|
||||
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_models.llama3.api.datatypes import Message, Role
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
|
@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
self.model_dir = f"meta-llama/{self.get_model_name()}"
|
||||
|
||||
if self.is_lg_vision():
|
||||
|
||||
self.model = MllamaForConditionalGeneration.from_pretrained(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue