Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)

load models using hf model id (#108)
Commit 3ae1597b9b (parent e73e9110b7)
1 changed file with 5 additions and 2 deletions

@@ -14,11 +14,12 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     MllamaForConditionalGeneration,
-    MllamaProcessor,
+    MllamaProcessor
 )


 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_models.llama3.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import *  # noqa: F403

+
 SAFE_RESPONSE = "safe"

@@ -146,6 +147,8 @@ class LlamaGuardShield(ShieldBase):

         torch_dtype = torch.bfloat16

+        self.model_dir = f"meta-llama/{self.get_model_name()}"
+
         if self.is_lg_vision():

             self.model = MllamaForConditionalGeneration.from_pretrained(
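
The practical effect of the added line is that the shield now points transformers at a Hugging Face Hub model id (meta-llama/<model name>) instead of a local checkpoint directory. Below is a minimal standalone sketch of that loading pattern, not the repository's exact code; the model name "Llama-Guard-3-8B" and the device_map setting are assumptions for illustration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed model name for illustration; in the shield this value would come
# from get_model_name(), and the f-string mirrors the line added in the diff.
model_name = "Llama-Guard-3-8B"
model_id = f"meta-llama/{model_name}"

# Passing a Hub id (rather than a local path) makes transformers resolve and
# download the weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # same dtype as in the surrounding code
    device_map="auto",           # assumption: let transformers pick device placement
)

Note that the meta-llama checkpoints on the Hub are gated, so an authenticated Hugging Face token (for example via huggingface-cli login) is needed before from_pretrained can download them.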