diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py
index 78cb17733..69ea512c5 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/config.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/config.py
@@ -15,14 +15,8 @@ class PromptGuardType(Enum):
     jailbreak = "jailbreak"
 
 
-class PromptGuardExecutionType(Enum):
-    cpu = "cpu"
-    cuda = "cuda"
-
-
 class PromptGuardConfig(BaseModel):
     guard_type: str = PromptGuardType.injection.value
-    guard_execution_type: str = PromptGuardExecutionType.cuda.value
 
     @classmethod
     @field_validator("guard_type")
@@ -31,16 +25,8 @@ class PromptGuardConfig(BaseModel):
             raise ValueError(f"Unknown prompt guard type: {v}")
         return v
 
-    @classmethod
-    @field_validator("guard_execution_type")
-    def validate_guard_execution_type(cls, v):
-        if v not in [t.value for t in PromptGuardExecutionType]:
-            raise ValueError(f"Unknown prompt guard execution type: {v}")
-        return v
-
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "guard_type": "injection",
-            "guard_execution_type": "cuda",
         }
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 77fb28ce8..ff87889ea 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -75,7 +75,9 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold
 
-        self.device = self.config.guard_execution_type
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"
 
         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)