Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-11 20:40:40 +00:00
chore: add mypy prompt guard (#2678)
# What does this PR do?

This PR adds static type coverage to `llama-stack`.

Part of https://github.com/meta-llama/llama-stack/issues/2647

## Test Plan

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
parent 7448a4a88c
commit b5b5f5b9ae
2 changed files with 8 additions and 4 deletions
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config
 
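The new `shield_store: ShieldStore` declaration exists for mypy's benefit: `__init__` never assigns the attribute (it is injected by the provider framework after construction, the usual pattern for `ShieldsProtocolPrivate` implementations), so without a class-level annotation mypy would reject every `self.shield_store` access. A minimal sketch of the pattern, with hypothetical `Store`/`Impl` names standing in for the real classes:

```python
# Minimal sketch (hypothetical names) of the pattern used above:
# a class-level annotation declares an attribute that is only
# assigned from outside the class, so mypy can still type-check it.
class Store:
    async def get_shield(self, shield_id: str) -> str | None:
        return shield_id


class Impl:
    store: Store  # declared here, but injected after construction

    async def run(self, shield_id: str) -> str | None:
        # Without the annotation above, mypy reports:
        #   "Impl" has no attribute "store"  [attr-defined]
        return await self.store.get_shield(shield_id)


impl = Impl()
impl.store = Store()  # runtime injection by the framework
```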
@ -53,7 +56,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
|
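Dropping `= None` fixes an implicit-Optional error: `no_implicit_optional` has been on by default since mypy 0.990, so a parameter annotated `dict[str, Any]` may not default to `None`. A short sketch of the error and the two standard fixes, of which the diff takes the first:

```python
from typing import Any

# Rejected by mypy's default no_implicit_optional, with an error
# along the lines of: Incompatible default for argument "params"
# (default has type "None", argument has type "dict[str, Any]").
# def run_shield(params: dict[str, Any] = None) -> None: ...

# Fix used in the diff: drop the default, so callers must pass params.
def run_shield(params: dict[str, Any]) -> None:
    print(params)

# Alternative fix: spell the Optional out and keep the default.
def run_shield_optional(params: dict[str, Any] | None = None) -> None:
    print(params or {})

run_shield({"threshold": 0.9})
run_shield_optional()
```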
@@ -117,8 +120,10 @@ class PromptGuardShield:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )
 
         return RunShieldResponse(violation=violation)
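This hunk is more than a rename: `violation_type` and `violation_return_message` are not fields of `SafetyViolation`, so with pydantic's default `extra="ignore"` the old keyword arguments would have been dropped silently at runtime, and mypy's pydantic plugin flags them as unexpected. A minimal sketch, assuming `SafetyViolation` is a pydantic model roughly like the one below (field names taken from the diff; the real definition lives in `llama_stack.apis.safety`):

```python
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class ViolationLevel(Enum):
    ERROR = "error"


class SafetyViolation(BaseModel):
    violation_level: ViolationLevel
    user_message: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


# New-style call from the diff: the human-readable message goes in
# user_message, and the detailed violation type rides in metadata.
violation = SafetyViolation(
    violation_level=ViolationLevel.ERROR,
    user_message="Sorry, I cannot do this.",
    metadata={"violation_type": "prompt_injection:malicious=0.99"},
)
print(violation.model_dump())
```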
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -266,7 +266,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
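The `pyproject.toml` hunk is what actually turns the checks on: these regexes feed mypy's `exclude` setting, and deleting the `prompt_guard` pattern brings the package into scope, so a plain `mypy .` (or the project's pre-commit hook) now type-checks it. The surrounding `code_scanner`, `llama_guard`, and `scoring` patterns remain excluded, presumably to be removed by follow-up PRs under issue #2647.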