chore: add mypy prompt guard (#2678)

# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
This PR adds static type coverage to `llama-stack` by enabling mypy for the inline Prompt Guard safety provider and fixing the type errors that surfaced.

Part of https://github.com/meta-llama/llama-stack/issues/2647

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
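
A quick local verification (assuming mypy is available in the dev environment) is to type-check the provider that this PR removes from the mypy `exclude` list, e.g. `uv run mypy llama_stack/providers/inline/safety/prompt_guard/` if the environment is managed with uv, or plain `mypy` with the same path, and confirm it reports no errors.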

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Authored by Mustafa Elbehery on 2025-08-11 17:40:40 +02:00, committed by GitHub
parent 7448a4a88c
commit b5b5f5b9ae
2 changed files with 8 additions and 4 deletions

Prompt Guard safety provider (`llama_stack/providers/inline/safety/prompt_guard/`):

```diff
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config
```
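
The `ShieldStore` import and the class-level `shield_store: ShieldStore` annotation give mypy the type of an attribute that is not assigned in `__init__` (it appears to be injected by the stack after construction). Below is a minimal, self-contained sketch of that pattern; apart from the `ShieldStore` name, the identifiers are illustrative and not the provider's real wiring:

```python
from typing import Protocol


class ShieldStore(Protocol):
    """Illustrative stand-in; the real ShieldStore is imported from llama_stack.apis.safety."""

    async def get_shield(self, identifier: str) -> object: ...


class ProviderImpl:
    # Class-level annotation: tells mypy the attribute's type even though the
    # value is injected after construction rather than set in __init__.
    shield_store: ShieldStore

    async def run_shield(self, shield_id: str) -> None:
        # Without the annotation above, mypy reports
        # 'ProviderImpl' has no attribute 'shield_store'.
        await self.shield_store.get_shield(shield_id)
```

Annotating at class level keeps runtime behaviour unchanged while giving the checker the information it needs.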
```diff
@@ -53,7 +56,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
```
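
mypy rejects `params: dict[str, Any] = None` because `None` is not a `dict[str, Any]`, and implicit `Optional` defaults are rejected by modern mypy (the default flipped in mypy 0.990). Dropping the default, as done here, is one fix; spelling the optionality out is the other. A small sketch of both valid spellings (the function names are illustrative):

```python
from typing import Any

# Rejected by mypy ("Incompatible default for argument"):
# def run_shield(params: dict[str, Any] = None) -> None: ...


def run_shield_required(params: dict[str, Any]) -> None:
    """Matches this PR: the caller must always pass params."""


def run_shield_optional(params: dict[str, Any] | None = None) -> None:
    """Alternative if a default is still wanted: make the Optional explicit."""
```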
```diff
@@ -117,8 +120,10 @@ class PromptGuardShield:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )
         return RunShieldResponse(violation=violation)
```
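
The old keyword arguments (`violation_type`, `violation_return_message`) were presumably flagged because they are not fields on the `SafetyViolation` model; the rewritten call uses `user_message` and stows the score in `metadata`. A minimal sketch of the resulting construction; the field names come from the diff above, while the model definition below is an illustrative stand-in for the real one in `llama_stack.apis.safety`:

```python
from enum import Enum
from typing import Any

from pydantic import BaseModel


class ViolationLevel(Enum):
    ERROR = "error"  # illustrative value; only the member name is taken from the diff


class SafetyViolation(BaseModel):
    violation_level: ViolationLevel
    user_message: str | None = None
    metadata: dict[str, Any] = {}


score_malicious = 0.97  # illustrative classifier score
violation = SafetyViolation(
    violation_level=ViolationLevel.ERROR,
    user_message="Sorry, I cannot do this.",
    metadata={"violation_type": f"prompt_injection:malicious={score_malicious}"},
)
```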

mypy configuration (`exclude` list):

```diff
@@ -266,7 +266,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
```