From fdaec91747c6b86b866a1c0c69ee2f625cd42bf5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 7 Nov 2024 14:35:04 -0800 Subject: [PATCH] Split safety into (llama-guard, prompt-guard, code-scanner) --- .../code_scanner}/__init__.py | 0 .../code_scanner}/code_scanner.py | 0 .../code_scanner}/config.py | 2 +- .../inline/safety/llama_guard/__init__.py | 19 +++ .../{meta_reference => llama_guard}/config.py | 15 +- .../llama_guard.py | 87 ++++++++--- .../inline/safety/meta_reference/__init__.py | 17 -- .../inline/safety/meta_reference/base.py | 57 ------- .../safety/meta_reference/prompt_guard.py | 145 ------------------ .../inline/safety/meta_reference/safety.py | 107 ------------- .../inline/safety/prompt_guard/__init__.py | 15 ++ .../inline/safety/prompt_guard/config.py | 25 +++ .../safety/prompt_guard/prompt_guard.py | 128 ++++++++++++++++ llama_stack/providers/registry/safety.py | 46 ++++-- 14 files changed, 295 insertions(+), 368 deletions(-) rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/code_scanner.py (100%) rename llama_stack/providers/inline/{meta_reference/codeshield => safety/code_scanner}/config.py (87%) create mode 100644 llama_stack/providers/inline/safety/llama_guard/__init__.py rename llama_stack/providers/inline/safety/{meta_reference => llama_guard}/config.py (75%) rename llama_stack/providers/inline/safety/{meta_reference => llama_guard}/llama_guard.py (76%) delete mode 100644 llama_stack/providers/inline/safety/meta_reference/__init__.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/base.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/prompt_guard.py delete mode 100644 llama_stack/providers/inline/safety/meta_reference/safety.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/__init__.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/config.py create mode 100644 llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py diff --git a/llama_stack/providers/inline/meta_reference/codeshield/__init__.py b/llama_stack/providers/inline/safety/code_scanner/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/codeshield/__init__.py rename to llama_stack/providers/inline/safety/code_scanner/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py rename to llama_stack/providers/inline/safety/code_scanner/code_scanner.py diff --git a/llama_stack/providers/inline/meta_reference/codeshield/config.py b/llama_stack/providers/inline/safety/code_scanner/config.py similarity index 87% rename from llama_stack/providers/inline/meta_reference/codeshield/config.py rename to llama_stack/providers/inline/safety/code_scanner/config.py index 583c2c95f..75c90d69a 100644 --- a/llama_stack/providers/inline/meta_reference/codeshield/config.py +++ b/llama_stack/providers/inline/safety/code_scanner/config.py @@ -7,5 +7,5 @@ from pydantic import BaseModel -class CodeShieldConfig(BaseModel): +class CodeScannerConfig(BaseModel): pass diff --git a/llama_stack/providers/inline/safety/llama_guard/__init__.py b/llama_stack/providers/inline/safety/llama_guard/__init__.py new file mode 100644 index 000000000..6024f840c --- 
/dev/null +++ b/llama_stack/providers/inline/safety/llama_guard/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import LlamaGuardConfig + + +async def get_provider_impl(config: LlamaGuardConfig, deps): + from .llama_guard import LlamaGuardSafetyImpl + + assert isinstance( + config, LlamaGuardConfig + ), f"Unexpected config type: {type(config)}" + + impl = LlamaGuardSafetyImpl(config, deps) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/safety/meta_reference/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py similarity index 75% rename from llama_stack/providers/inline/safety/meta_reference/config.py rename to llama_stack/providers/inline/safety/llama_guard/config.py index 14233ad0c..aec856bce 100644 --- a/llama_stack/providers/inline/safety/meta_reference/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -4,20 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum -from typing import List, Optional +from typing import List from llama_models.sku_list import CoreModelId, safety_models from pydantic import BaseModel, field_validator -class PromptGuardType(Enum): - injection = "injection" - jailbreak = "jailbreak" - - -class LlamaGuardShieldConfig(BaseModel): +class LlamaGuardConfig(BaseModel): model: str = "Llama-Guard-3-1B" excluded_categories: List[str] = [] @@ -41,8 +35,3 @@ class LlamaGuardShieldConfig(BaseModel): f"Invalid model: {model}. Must be one of {permitted_models}" ) return model - - -class SafetyConfig(BaseModel): - llama_guard_shield: Optional[LlamaGuardShieldConfig] = None - enable_prompt_guard: Optional[bool] = False diff --git a/llama_stack/providers/inline/safety/meta_reference/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py similarity index 76% rename from llama_stack/providers/inline/safety/meta_reference/llama_guard.py rename to llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 99b1c29be..bac153a4d 100644 --- a/llama_stack/providers/inline/safety/meta_reference/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -7,16 +7,67 @@ import re from string import Template -from typing import List, Optional +from typing import Any, Dict, List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.distribution.datatypes import Api -from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from .config import LlamaGuardConfig + + +class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): + def __init__(self, config: LlamaGuardConfig, deps) -> None: + self.config = config + self.inference_api = deps[Api.inference] + + async def initialize(self) -> None: + self.shield = LlamaGuardShield( + model=self.config.model, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + ) + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: ShieldDef) -> None: + raise ValueError("Registering dynamic shields is not supported") + + async 
def list_shields(self) -> List[ShieldDef]: + return [ + ShieldDef( + identifier=ShieldType.llama_guard.value, + shield_type=ShieldType.llama_guard.value, + params={}, + ), + ] + + async def run_shield( + self, + identifier: str, + messages: List[Message], + params: Dict[str, Any] = None, + ) -> RunShieldResponse: + shield_def = await self.shield_store.get_shield(identifier) + if not shield_def: + raise ValueError(f"Unknown shield {identifier}") + + messages = messages.copy() + # some shields like llama-guard require the first message to be a user message + # since this might be a tool call, first role might not be user + if len(messages) > 0 and messages[0].role != Role.user.value: + messages[0] = UserMessage(content=messages[0].content) + + return await self.shield.run(messages) + + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" SAFE_RESPONSE = "safe" -_INSTANCE = None CAT_VIOLENT_CRIMES = "Violent Crimes" CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes" @@ -107,16 +158,13 @@ PROMPT_TEMPLATE = Template( ) -class LlamaGuardShield(ShieldBase): +class LlamaGuardShield: def __init__( self, model: str, inference_api: Inference, - excluded_categories: List[str] = None, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, + excluded_categories: Optional[List[str]] = None, ): - super().__init__(on_violation_action) - if excluded_categories is None: excluded_categories = [] @@ -174,7 +222,7 @@ class LlamaGuardShield(ShieldBase): ) return messages - async def run(self, messages: List[Message]) -> ShieldResponse: + async def run(self, messages: List[Message]) -> RunShieldResponse: messages = self.validate_messages(messages) if self.model == CoreModelId.llama_guard_3_11b_vision.value: @@ -195,8 +243,7 @@ class LlamaGuardShield(ShieldBase): content += event.delta content = content.strip() - shield_response = self.get_shield_response(content) - return shield_response + return self.get_shield_response(content) def build_text_shield_input(self, messages: List[Message]) -> UserMessage: return UserMessage(content=self.build_prompt(messages)) @@ -250,19 +297,23 @@ class LlamaGuardShield(ShieldBase): conversations=conversations_str, ) - def get_shield_response(self, response: str) -> ShieldResponse: + def get_shield_response(self, response: str) -> RunShieldResponse: response = response.strip() if response == SAFE_RESPONSE: - return ShieldResponse(is_violation=False) + return RunShieldResponse(violation=None) + unsafe_code = self.check_unsafe_response(response) if unsafe_code: unsafe_code_list = unsafe_code.split(",") if set(unsafe_code_list).issubset(set(self.excluded_categories)): - return ShieldResponse(is_violation=False) - return ShieldResponse( - is_violation=True, - violation_type=unsafe_code, - violation_return_message=CANNED_RESPONSE_TEXT, + return RunShieldResponse(violation=None) + + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message=CANNED_RESPONSE_TEXT, + metadata={"violation_type": unsafe_code}, + ), ) raise ValueError(f"Unexpected response: {response}") diff --git a/llama_stack/providers/inline/safety/meta_reference/__init__.py b/llama_stack/providers/inline/safety/meta_reference/__init__.py deleted file mode 100644 index 5e0888de6..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import LlamaGuardShieldConfig, SafetyConfig # noqa: F401 - - -async def get_provider_impl(config: SafetyConfig, deps): - from .safety import MetaReferenceSafetyImpl - - assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" - - impl = MetaReferenceSafetyImpl(config, deps) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/safety/meta_reference/base.py b/llama_stack/providers/inline/safety/meta_reference/base.py deleted file mode 100644 index 3861a7c4a..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/base.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import List - -from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message -from pydantic import BaseModel -from llama_stack.apis.safety import * # noqa: F403 - -CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" - - -# TODO: clean this up; just remove this type completely -class ShieldResponse(BaseModel): - is_violation: bool - violation_type: Optional[str] = None - violation_return_message: Optional[str] = None - - -# TODO: this is a caller / agent concern -class OnViolationAction(Enum): - IGNORE = 0 - WARN = 1 - RAISE = 2 - - -class ShieldBase(ABC): - def __init__( - self, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - self.on_violation_action = on_violation_action - - @abstractmethod - async def run(self, messages: List[Message]) -> ShieldResponse: - raise NotImplementedError() - - -def message_content_as_str(message: Message) -> str: - return interleaved_text_media_as_str(message.content) - - -class TextShield(ShieldBase): - def convert_messages_to_text(self, messages: List[Message]) -> str: - return "\n".join([message_content_as_str(m) for m in messages]) - - async def run(self, messages: List[Message]) -> ShieldResponse: - text = self.convert_messages_to_text(messages) - return await self.run_impl(text) - - @abstractmethod - async def run_impl(self, text: str) -> ShieldResponse: - raise NotImplementedError() diff --git a/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py b/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py deleted file mode 100644 index 54e911418..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from enum import auto, Enum -from typing import List - -import torch - -from llama_models.llama3.api.datatypes import Message -from termcolor import cprint - -from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield - - -class PromptGuardShield(TextShield): - class Mode(Enum): - INJECTION = auto() - JAILBREAK = auto() - - _instances = {} - _model_cache = None - - @staticmethod - def instance( - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, - on_violation_action=OnViolationAction.RAISE, - ) -> "PromptGuardShield": - action_value = on_violation_action.value - key = (model_dir, threshold, temperature, mode, action_value) - if key not in PromptGuardShield._instances: - PromptGuardShield._instances[key] = PromptGuardShield( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=mode, - on_violation_action=on_violation_action, - ) - return PromptGuardShield._instances[key] - - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - mode: "PromptGuardShield.Mode" = Mode.JAILBREAK, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - super().__init__(on_violation_action) - assert ( - model_dir is not None - ), "Must provide a model directory for prompt injection shield" - if temperature <= 0: - raise ValueError("Temperature must be greater than 0") - self.device = "cuda" - if PromptGuardShield._model_cache is None: - from transformers import AutoModelForSequenceClassification, AutoTokenizer - - # load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_dir) - model = AutoModelForSequenceClassification.from_pretrained( - model_dir, device_map=self.device - ) - PromptGuardShield._model_cache = (tokenizer, model) - - self.tokenizer, self.model = PromptGuardShield._model_cache - self.temperature = temperature - self.threshold = threshold - self.mode = mode - - def convert_messages_to_text(self, messages: List[Message]) -> str: - return message_content_as_str(messages[-1]) - - async def run_impl(self, text: str) -> ShieldResponse: - # run model on messages and return response - inputs = self.tokenizer(text, return_tensors="pt") - inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} - with torch.no_grad(): - outputs = self.model(**inputs) - logits = outputs[0] - probabilities = torch.softmax(logits / self.temperature, dim=-1) - score_embedded = probabilities[0, 1].item() - score_malicious = probabilities[0, 2].item() - cprint( - f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", - color="magenta", - ) - - if self.mode == self.Mode.INJECTION and ( - score_embedded + score_malicious > self.threshold - ): - return ShieldResponse( - is_violation=True, - violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", - violation_return_message="Sorry, I cannot do this.", - ) - elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold: - return ShieldResponse( - is_violation=True, - violation_type=f"prompt_injection:malicious={score_malicious}", - violation_return_message="Sorry, I cannot do this.", - ) - - return ShieldResponse( - is_violation=False, - ) - - -class JailbreakShield(PromptGuardShield): - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - 
super().__init__( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=PromptGuardShield.Mode.JAILBREAK, - on_violation_action=on_violation_action, - ) - - -class InjectionShield(PromptGuardShield): - def __init__( - self, - model_dir: str, - threshold: float = 0.9, - temperature: float = 1.0, - on_violation_action: OnViolationAction = OnViolationAction.RAISE, - ): - super().__init__( - model_dir=model_dir, - threshold=threshold, - temperature=temperature, - mode=PromptGuardShield.Mode.INJECTION, - on_violation_action=on_violation_action, - ) diff --git a/llama_stack/providers/inline/safety/meta_reference/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py deleted file mode 100644 index 824a7cd7e..000000000 --- a/llama_stack/providers/inline/safety/meta_reference/safety.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, List - -from llama_stack.distribution.utils.model_utils import model_local_dir -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import Api - -from llama_stack.providers.datatypes import ShieldsProtocolPrivate - -from .base import OnViolationAction, ShieldBase -from .config import SafetyConfig -from .llama_guard import LlamaGuardShield -from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield - - -PROMPT_GUARD_MODEL = "Prompt-Guard-86M" -SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard] - - -class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): - def __init__(self, config: SafetyConfig, deps) -> None: - self.config = config - self.inference_api = deps[Api.inference] - - self.available_shields = [] - if config.llama_guard_shield: - self.available_shields.append(ShieldType.llama_guard) - if config.enable_prompt_guard: - self.available_shields.append(ShieldType.prompt_guard) - - async def initialize(self) -> None: - if self.config.enable_prompt_guard: - model_dir = model_local_dir(PROMPT_GUARD_MODEL) - _ = PromptGuardShield.instance(model_dir) - - async def shutdown(self) -> None: - pass - - async def register_shield(self, shield: Shield) -> None: - if shield.shield_type not in self.available_shields: - raise ValueError(f"Shield type {shield.shield_type} not supported") - - async def run_shield( - self, - shield_id: str, - messages: List[Message], - params: Dict[str, Any] = None, - ) -> RunShieldResponse: - shield = await self.shield_store.get_shield(shield_id) - if not shield: - raise ValueError(f"Shield {shield_id} not found") - - shield_impl = self.get_shield_impl(shield) - - messages = messages.copy() - # some shields like llama-guard require the first message to be a user message - # since this might be a tool call, first role might not be user - if len(messages) > 0 and messages[0].role != Role.user.value: - messages[0] = UserMessage(content=messages[0].content) - - # TODO: we can refactor ShieldBase, etc. 
to be inline with the API types - res = await shield_impl.run(messages) - violation = None - if ( - res.is_violation - and shield_impl.on_violation_action != OnViolationAction.IGNORE - ): - violation = SafetyViolation( - violation_level=( - ViolationLevel.ERROR - if shield_impl.on_violation_action == OnViolationAction.RAISE - else ViolationLevel.WARN - ), - user_message=res.violation_return_message, - metadata={ - "violation_type": res.violation_type, - }, - ) - - return RunShieldResponse(violation=violation) - - def get_shield_impl(self, shield: Shield) -> ShieldBase: - if shield.shield_type == ShieldType.llama_guard: - cfg = self.config.llama_guard_shield - return LlamaGuardShield( - model=cfg.model, - inference_api=self.inference_api, - excluded_categories=cfg.excluded_categories, - ) - elif shield.shield_type == ShieldType.prompt_guard: - model_dir = model_local_dir(PROMPT_GUARD_MODEL) - subtype = shield.params.get("prompt_guard_type", "injection") - if subtype == "injection": - return InjectionShield.instance(model_dir) - elif subtype == "jailbreak": - return JailbreakShield.instance(model_dir) - else: - raise ValueError(f"Unknown prompt guard type: {subtype}") - else: - raise ValueError(f"Unknown shield type: {shield.shield_type}") diff --git a/llama_stack/providers/inline/safety/prompt_guard/__init__.py b/llama_stack/providers/inline/safety/prompt_guard/__init__.py new file mode 100644 index 000000000..087aca6d9 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import PromptGuardConfig # noqa: F401 + + +async def get_provider_impl(config: PromptGuardConfig, deps): + from .prompt_guard import PromptGuardSafetyImpl + + impl = PromptGuardSafetyImpl(config, deps) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py new file mode 100644 index 000000000..bddd28452 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum + +from pydantic import BaseModel, field_validator + + +class PromptGuardType(Enum): + injection = "injection" + jailbreak = "jailbreak" + + +class PromptGuardConfig(BaseModel): + guard_type: str = PromptGuardType.injection.value + + @field_validator("guard_type") + @classmethod + def validate_guard_type(cls, v): + if v not in [t.value for t in PromptGuardType]: + raise ValueError(f"Unknown prompt guard type: {v}") + return v diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py new file mode 100644 index 000000000..5cfafcde4 --- /dev/null +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree.
+ +from typing import Any, Dict, List + +import torch + +from llama_stack.distribution.utils.model_utils import model_local_dir +from termcolor import cprint + +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 + +from llama_stack.providers.datatypes import ShieldsProtocolPrivate + +from .config import PromptGuardConfig, PromptGuardType + + +PROMPT_GUARD_MODEL = "Prompt-Guard-86M" + + +class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): + def __init__(self, config: PromptGuardConfig, _deps) -> None: + self.config = config + + async def initialize(self) -> None: + model_dir = model_local_dir(PROMPT_GUARD_MODEL) + self.shield = PromptGuardShield(model_dir, self.config) + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: ShieldDef) -> None: + raise ValueError("Registering dynamic shields is not supported") + + async def list_shields(self) -> List[ShieldDef]: + return [ + ShieldDef( + identifier=ShieldType.prompt_guard.value, + shield_type=ShieldType.prompt_guard.value, + params={}, + ) + ] + + async def run_shield( + self, + identifier: str, + messages: List[Message], + params: Dict[str, Any] = None, + ) -> RunShieldResponse: + shield_def = await self.shield_store.get_shield(identifier) + if not shield_def: + raise ValueError(f"Unknown shield {identifier}") + + return await self.shield.run(messages) + + +class PromptGuardShield: + def __init__( + self, + model_dir: str, + config: PromptGuardConfig, + threshold: float = 0.9, + temperature: float = 1.0, + ): + assert ( + model_dir is not None + ), "Must provide a model directory for prompt injection shield" + if temperature <= 0: + raise ValueError("Temperature must be greater than 0") + + self.config = config + self.temperature = temperature + self.threshold = threshold + + self.device = "cuda" + + # load model and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_dir, device_map=self.device + ) + + async def run(self, messages: List[Message]) -> RunShieldResponse: + message = messages[-1] + text = interleaved_text_media_as_str(message.content) + + # run model on messages and return response + inputs = self.tokenizer(text, return_tensors="pt") + inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} + with torch.no_grad(): + outputs = self.model(**inputs) + logits = outputs[0] + probabilities = torch.softmax(logits / self.temperature, dim=-1) + score_embedded = probabilities[0, 1].item() + score_malicious = probabilities[0, 2].item() + cprint( + f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}", + color="magenta", + ) + + violation = None + if self.config.guard_type == PromptGuardType.injection.value and ( + score_embedded + score_malicious > self.threshold + ): + violation = SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Sorry, I cannot do this.", + metadata={ + "violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}", + }, + ) + elif ( + self.config.guard_type == PromptGuardType.jailbreak.value + and score_malicious > self.threshold + ): + violation = SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message="Sorry, I cannot do this.",
metadata={"violation_type": f"prompt_injection:malicious={score_malicious}"}, + ) + + return RunShieldResponse(violation=violation) diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index fb5b6695a..668419338 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -29,6 +29,42 @@ def available_providers() -> List[ProviderSpec]: api_dependencies=[ Api.inference, ], + deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.", + ), + InlineProviderSpec( + api=Api.safety, + provider_type="llama-guard", + pip_packages=[], + module="llama_stack.providers.inline.safety.llama_guard", + config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig", + api_dependencies=[ + Api.inference, + ], + ), + InlineProviderSpec( + api=Api.safety, + provider_type="prompt-guard", + pip_packages=[ + "transformers", + "torch --index-url https://download.pytorch.org/whl/cpu", + ], + module="llama_stack.providers.inline.safety.prompt_guard", + config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", + api_dependencies=[ + Api.inference, + ], + ), + InlineProviderSpec( + api=Api.safety, + provider_type="code-scanner", + pip_packages=[ + "codeshield", + ], + module="llama_stack.providers.inline.safety.code_scanner", + config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", + api_dependencies=[ + Api.inference, + ], ), remote_provider_spec( api=Api.safety, @@ -48,14 +84,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", ), ), - InlineProviderSpec( - api=Api.safety, - provider_type="meta-reference/codeshield", - pip_packages=[ - "codeshield", - ], - module="llama_stack.providers.inline.safety.meta_reference", - config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig", - api_dependencies=[], - ), ]
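
Usage note (not part of the patch itself): below is a minimal sketch of how the new `llama-guard` provider could be wired up and called directly, for reviewers trying out the split providers. It assumes an `inference_impl` object that implements the Inference API and uses a hypothetical StaticShieldStore stand-in; in a real deployment both are injected by the llama-stack distribution resolver, and the provider is selected through the `llama-guard` registry entry added above.

    from llama_models.llama3.api.datatypes import UserMessage
    from llama_stack.distribution.datatypes import Api
    from llama_stack.providers.inline.safety.llama_guard import (
        LlamaGuardConfig,
        get_provider_impl,
    )

    class StaticShieldStore:
        # hypothetical stand-in for the shield store the resolver normally injects
        def __init__(self, shields):
            self.shields = {s.identifier: s for s in shields}

        async def get_shield(self, identifier):
            return self.shields.get(identifier)

    async def moderate(inference_impl, user_text: str):
        # build the provider the same way the registry entry would
        config = LlamaGuardConfig(model="Llama-Guard-3-1B", excluded_categories=[])
        impl = await get_provider_impl(config, {Api.inference: inference_impl})

        # the provider advertises a single static llama_guard shield
        shields = await impl.list_shields()
        impl.shield_store = StaticShieldStore(shields)

        response = await impl.run_shield(
            identifier=shields[0].identifier,
            messages=[UserMessage(content=user_text)],
        )
        return response.violation  # None when Llama Guard judges the content safe

The `prompt-guard` provider is driven the same way; the injection vs. jailbreak behavior that used to be selected through shield params is now chosen at configuration time, e.g. `PromptGuardConfig(guard_type="jailbreak")`, with `"injection"` as the default.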