diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index ff87889ea..f94da4dc1 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    ShieldStore,
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
@@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    shield_store: ShieldStore
+
     def __init__(self, config: PromptGuardConfig, _deps) -> None:
         self.config = config
 
@@ -50,7 +53,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -114,8 +117,10 @@ class PromptGuardShield:
         elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
             violation = SafetyViolation(
                 violation_level=ViolationLevel.ERROR,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:malicious={score_malicious}",
+                },
             )
 
         return RunShieldResponse(violation=violation)
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index fafd1d8ff..bac3128bf 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -24,10 +24,29 @@ def parse_hf_params(dataset_def: Dataset):
     uri = dataset_def.source.uri
     parsed_uri = urlparse(uri)
     params = parse_qs(parsed_uri.query)
-    params = {k: v[0] for k, v in params.items()}
+
+    # Convert parameters to appropriate types
+    processed_params = {}
+    for k, v in params.items():
+        # Parameters that should remain as arrays/lists
+        if k in ("paths", "data_files"):
+            processed_params[k] = v  # Keep as list
+        # Parameters that should be booleans
+        elif k in ("expand", "streaming", "trust_remote_code"):
+            processed_params[k] = v[0].lower() in ("true", "1", "yes")
+        # Parameters that should be integers
+        elif k in ("num_proc", "batch_size"):
+            try:
+                processed_params[k] = int(v[0])
+            except ValueError:
+                raise ValueError(f"Parameter '{k}' must be an integer, got '{v[0]}'") from None
+        # All other parameters remain as strings
+        else:
+            processed_params[k] = v[0]
+
     path = parsed_uri.path.lstrip("/")
 
-    return path, params
+    return path, processed_params
 
 
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
diff --git a/pyproject.toml b/pyproject.toml
index ad4bb7314..6fd0c3773 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -261,7 +261,6 @@ exclude = [
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",
-    "^llama_stack/providers/inline/safety/prompt_guard/",
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
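
For reviewers, a minimal sketch of the reshaped violation object in the prompt_guard hunk, using only names visible in this patch (SafetyViolation, ViolationLevel, RunShieldResponse); the score value is made up:

from llama_stack.apis.safety import RunShieldResponse, SafetyViolation, ViolationLevel

# Hypothetical classifier score, standing in for score_malicious above.
score_malicious = 0.98

# Build the violation the way the patched PromptGuardShield now does:
# the refusal text moves to user_message, and the machine-readable
# violation type moves into the metadata dict.
violation = SafetyViolation(
    violation_level=ViolationLevel.ERROR,
    user_message="Sorry, I cannot do this.",
    metadata={"violation_type": f"prompt_injection:malicious={score_malicious}"},
)
response = RunShieldResponse(violation=violation)

# Callers that previously read violation.violation_type should now
# look under metadata instead.
print(response.violation.metadata["violation_type"])  # prompt_injection:malicious=0.98

If SafetyViolation carries required fields beyond those shown in the hunk, the sketch would need those as well.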
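
And a standalone sketch of what the new parse_hf_params coercion does to a query string (stdlib only; the URI and its keys are hypothetical, but the branch logic mirrors the hunk above):

from urllib.parse import parse_qs, urlparse

# Hypothetical dataset URI, chosen to exercise each branch of the
# new coercion logic.
uri = "hf://tatsu-lab/alpaca?split=train&streaming=true&num_proc=4&data_files=a.json&data_files=b.json"

params = parse_qs(urlparse(uri).query)

processed = {}
for k, v in params.items():
    if k in ("paths", "data_files"):
        processed[k] = v  # multi-valued keys stay lists
    elif k in ("expand", "streaming", "trust_remote_code"):
        processed[k] = v[0].lower() in ("true", "1", "yes")  # -> bool
    elif k in ("num_proc", "batch_size"):
        processed[k] = int(v[0])  # -> int (the patch raises ValueError on junk)
    else:
        processed[k] = v[0]  # single values stay strings

print(processed)
# {'split': 'train', 'streaming': True, 'num_proc': 4, 'data_files': ['a.json', 'b.json']}

The behavioral change versus the old one-liner ({k: v[0] for k, v in params.items()}) is that repeated keys such as data_files no longer silently drop every value after the first, and boolean/integer options reach the datasets loader with their proper types instead of as raw strings.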