mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
squash: address comments
Signed-off-by: Michael Dawson <mdawson@devrus.com>
This commit is contained in:
parent
f316dffe80
commit
7b9a6eda63
2 changed files with 3 additions and 15 deletions
|
@ -15,14 +15,8 @@ class PromptGuardType(Enum):
|
|||
jailbreak = "jailbreak"
|
||||
|
||||
|
||||
class PromptGuardExecutionType(Enum):
|
||||
cpu = "cpu"
|
||||
cuda = "cuda"
|
||||
|
||||
|
||||
class PromptGuardConfig(BaseModel):
|
||||
guard_type: str = PromptGuardType.injection.value
|
||||
guard_execution_type: str = PromptGuardExecutionType.cuda.value
|
||||
|
||||
@classmethod
|
||||
@field_validator("guard_type")
|
||||
|
@ -31,16 +25,8 @@ class PromptGuardConfig(BaseModel):
|
|||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("guard_execution_type")
|
||||
def validate_guard_execution_type(cls, v):
|
||||
if v not in [t.value for t in PromptGuardExecutionType]:
|
||||
raise ValueError(f"Unknown prompt guard execution type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
"guard_execution_type": "cuda",
|
||||
}
|
||||
|
|
|
@ -75,7 +75,9 @@ class PromptGuardShield:
|
|||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
|
||||
self.device = self.config.guard_execution_type
|
||||
self.device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
|
||||
# load model and tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue