From 7b9a6eda635cc8b75c2b0aabb078d1b7da96d3b8 Mon Sep 17 00:00:00 2001
From: Michael Dawson
Date: Mon, 26 May 2025 13:40:35 -0400
Subject: [PATCH] squash: address comments

Signed-off-by: Michael Dawson
---
 .../providers/inline/safety/prompt_guard/config.py | 14 --------------
 .../inline/safety/prompt_guard/prompt_guard.py     |  4 +++-
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py
index 78cb17733..69ea512c5 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/config.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/config.py
@@ -15,14 +15,8 @@ class PromptGuardType(Enum):
     jailbreak = "jailbreak"
 
 
-class PromptGuardExecutionType(Enum):
-    cpu = "cpu"
-    cuda = "cuda"
-
-
 class PromptGuardConfig(BaseModel):
     guard_type: str = PromptGuardType.injection.value
-    guard_execution_type: str = PromptGuardExecutionType.cuda.value
 
     @classmethod
     @field_validator("guard_type")
@@ -31,16 +25,8 @@ class PromptGuardConfig(BaseModel):
             raise ValueError(f"Unknown prompt guard type: {v}")
         return v
 
-    @classmethod
-    @field_validator("guard_execution_type")
-    def validate_guard_execution_type(cls, v):
-        if v not in [t.value for t in PromptGuardExecutionType]:
-            raise ValueError(f"Unknown prompt guard execution type: {v}")
-        return v
-
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "guard_type": "injection",
-            "guard_execution_type": "cuda",
         }
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 77fb28ce8..ff87889ea 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -75,7 +75,9 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold
 
-        self.device = self.config.guard_execution_type
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"
 
         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)