Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-26 22:19:49 +00:00)

chore: add mypy prompt guard
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

This commit is contained in:
  parent 1463b79218
  commit bd05e22004

3 changed files with 29 additions and 6 deletions

@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield

@@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config

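
The shield_store: ShieldStore class-level annotation is there purely for the type checker: the attribute is assigned by the surrounding framework after construction rather than in __init__, so without a declaration mypy has no way to know that self.shield_store exists. A minimal sketch of the pattern, with a hypothetical StoreProtocol standing in for the real ShieldStore interface:

from typing import Protocol


class StoreProtocol(Protocol):
    # Hypothetical stand-in for ShieldStore; only the one method used below.
    async def get_shield(self, identifier: str) -> object: ...


class ProviderImpl:
    # Declared but never assigned here: the attribute is injected later at
    # runtime, and the class-level annotation is what tells mypy it exists.
    shield_store: StoreProtocol

    async def lookup(self, shield_id: str) -> object:
        # Without the annotation above, mypy reports something like
        # '"ProviderImpl" has no attribute "shield_store"  [attr-defined]'.
        return await self.shield_store.get_shield(shield_id)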

@@ -50,7 +53,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
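
params: dict[str, Any] = None is exactly the kind of signature mypy rejects, since None is not a valid default for a parameter typed as dict[str, Any]. The diff makes the argument required; the other conventional fix, shown below only for contrast, is to widen the annotation to an optional type. The function names here are illustrative, not the provider's API:

from typing import Any


# Option taken in the diff above: the caller must always pass a dict.
async def run_required(shield_id: str, params: dict[str, Any]) -> None:
    ...


# Alternative not taken here: make None explicit in the type and handle it.
async def run_optional(shield_id: str, params: dict[str, Any] | None = None) -> None:
    params = params or {}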

@@ -114,8 +117,10 @@ class PromptGuardShield:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )

         return RunShieldResponse(violation=violation)
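
SafetyViolation does not define violation_type or violation_return_message fields, which is the kind of mistake that only surfaces once mypy starts checking the module; the reply text moves to user_message and the classification string is kept under metadata. A small hedged sketch of what the reshaped object carries, using only the field names visible in the diff (the score value is made up for illustration):

from llama_stack.apis.safety import SafetyViolation, ViolationLevel

score_malicious = 0.97  # illustrative value, not produced by the real classifier

violation = SafetyViolation(
    violation_level=ViolationLevel.ERROR,
    user_message="Sorry, I cannot do this.",
    metadata={"violation_type": f"prompt_injection:malicious={score_malicious}"},
)

# Callers that previously read violation_type can recover it from metadata,
# assuming metadata is exposed as a plain dict on the model:
reason = violation.metadata["violation_type"]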

@@ -24,10 +24,29 @@ def parse_hf_params(dataset_def: Dataset):
     uri = dataset_def.source.uri
     parsed_uri = urlparse(uri)
     params = parse_qs(parsed_uri.query)
-    params = {k: v[0] for k, v in params.items()}
+
+    # Convert parameters to appropriate types
+    processed_params = {}
+    for k, v in params.items():
+        # Parameters that should remain as arrays/lists
+        if k in ("paths", "data_files"):
+            processed_params[k] = v  # Keep as list
+        # Parameters that should be booleans
+        elif k in ("expand", "streaming", "trust_remote_code"):
+            processed_params[k] = v[0].lower() in ("true", "1", "yes")
+        # Parameters that should be integers
+        elif k in ("num_proc", "batch_size"):
+            try:
+                processed_params[k] = int(v[0])
+            except ValueError:
+                raise ValueError(f"Parameter '{k}' must be an integer, got '{v[0]}'") from None
+        # All other parameters remain as strings
+        else:
+            processed_params[k] = v[0]
+
     path = parsed_uri.path.lstrip("/")

-    return path, params
+    return path, processed_params


 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
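
The rewritten parse_hf_params stops collapsing every query value to its first string and instead coerces each known key to the type the downstream Hugging Face dataset loading expects: repeated keys such as data_files stay lists, flags become booleans, and counts become integers. A self-contained sketch of the same rules applied to a bare URI string (the URI and the helper name are illustrative, not the provider's):

from urllib.parse import parse_qs, urlparse


def coerce_query(uri: str) -> tuple[str, dict]:
    # Applies the same per-key coercion rules as the diff above.
    parsed = urlparse(uri)
    raw = parse_qs(parsed.query)
    out: dict = {}
    for k, v in raw.items():
        if k in ("paths", "data_files"):
            out[k] = v  # repeated keys survive as lists
        elif k in ("expand", "streaming", "trust_remote_code"):
            out[k] = v[0].lower() in ("true", "1", "yes")
        elif k in ("num_proc", "batch_size"):
            out[k] = int(v[0])
        else:
            out[k] = v[0]
    return parsed.path.lstrip("/"), out


path, params = coerce_query(
    "hf://datasets/org/name?split=train&streaming=true&num_proc=4"
    "&data_files=a.json&data_files=b.json"
)
# path   == "org/name"   (urlparse puts "datasets" into the netloc)
# params == {"split": "train", "streaming": True, "num_proc": 4,
#            "data_files": ["a.json", "b.json"]}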

@@ -261,7 +261,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
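
Dropping the exclude pattern is what brings the prompt_guard provider under mypy at all; the source changes above are the fixes that this forces. A quick way to reproduce the check from a checkout (a sketch; assumes mypy is installed and is run from the repository root, without whatever wrapper the project normally uses):

# Runs mypy programmatically on the now-included package and prints its report.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["llama_stack/providers/inline/safety/prompt_guard"]
)
print(stdout or stderr)
print("exit status:", exit_status)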