From f316dffe800de6e599fcdb355c5b43104d300c09 Mon Sep 17 00:00:00 2001
From: Michael Dawson
Date: Fri, 16 May 2025 14:26:44 -0400
Subject: [PATCH] feat: add cpu/cuda config for prompt guard

Previously, Prompt Guard was hard-coded to require CUDA, which prevented
it from being used on an instance without CUDA support. This PR allows
Prompt Guard to be configured to use either CPU or CUDA.

Signed-off-by: Michael Dawson
---
 .../providers/inline/safety/prompt_guard/config.py | 14 ++++++++++++++
 .../inline/safety/prompt_guard/prompt_guard.py     |  2 +-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py
index 69ea512c5..78cb17733 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/config.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/config.py
@@ -15,8 +15,14 @@ class PromptGuardType(Enum):
     jailbreak = "jailbreak"
 
 
+class PromptGuardExecutionType(Enum):
+    cpu = "cpu"
+    cuda = "cuda"
+
+
 class PromptGuardConfig(BaseModel):
     guard_type: str = PromptGuardType.injection.value
+    guard_execution_type: str = PromptGuardExecutionType.cuda.value
 
     @classmethod
     @field_validator("guard_type")
@@ -25,8 +31,16 @@ class PromptGuardConfig(BaseModel):
             raise ValueError(f"Unknown prompt guard type: {v}")
         return v
 
+    @classmethod
+    @field_validator("guard_execution_type")
+    def validate_guard_execution_type(cls, v):
+        if v not in [t.value for t in PromptGuardExecutionType]:
+            raise ValueError(f"Unknown prompt guard execution type: {v}")
+        return v
+
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "guard_type": "injection",
+            "guard_execution_type": "cuda",
         }
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 56ce8285f..77fb28ce8 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -75,7 +75,7 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold
 
-        self.device = "cuda"
+        self.device = self.config.guard_execution_type
 
         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
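
For context, a minimal sketch of how the new option could be exercised once this
change lands. It assumes the import path matches the config module touched above;
the values shown are illustrative, not project defaults.

    # Sketch only: exercise the new guard_execution_type field (assumed import path).
    from llama_stack.providers.inline.safety.prompt_guard.config import PromptGuardConfig

    # Run Prompt Guard on CPU instead of the previously hard-coded "cuda" device.
    config = PromptGuardConfig(guard_type="injection", guard_execution_type="cpu")
    print(config.guard_execution_type)  # "cpu"

    # The new validator is intended to reject anything outside {"cpu", "cuda"};
    # e.g. guard_execution_type="tpu" should raise a ValueError.

As the sample_run_config change above shows, generated run configs keep "cuda" as
the default for guard_execution_type, so existing CUDA deployments are unaffected.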