From b5b5f5b9ae2d85bf663d947f5a57616ec1bc52a6 Mon Sep 17 00:00:00 2001
From: Mustafa Elbehery
Date: Mon, 11 Aug 2025 17:40:40 +0200
Subject: [PATCH] chore: add `mypy` prompt guard (#2678)

# What does this PR do?

This PR adds static type coverage to `llama-stack`

Part of https://github.com/meta-llama/llama-stack/issues/2647

## Test Plan

Signed-off-by: Mustafa Elbehery
---
 .../inline/safety/prompt_guard/prompt_guard.py | 11 ++++++++---
 pyproject.toml                                 |  1 -
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 796771ee1..e11ec5cf5 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config
 
@@ -53,7 +56,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -117,8 +120,10 @@ class PromptGuardShield:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )
 
         return RunShieldResponse(violation=violation)
diff --git a/pyproject.toml b/pyproject.toml
index bb079790f..a77ec5ac9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -266,7 +266,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",