diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 56ce8285f..ff87889ea 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -75,7 +75,9 @@ class PromptGuardShield: self.temperature = temperature self.threshold = threshold - self.device = "cuda" + self.device = "cpu" + if torch.cuda.is_available(): + self.device = "cuda" # load model and tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_dir)