Split safety into (llama-guard, prompt-guard, code-scanner) (#400)

Splits the meta-reference safety implementation into three distinct providers:

- inline::llama-guard
- inline::prompt-guard
- inline::code-scanner

Note that this PR is a backward incompatible change to the llama stack server. I have added deprecation_error field to ProviderSpec -- the server reads it and immediately barfs. This is used to direct the user with a specific message on what action to perform. An automagical "config upgrade" is a bit too much work to implement right now :/

(Note that we will be gradually prefixing all inline providers with inline:: -- I am only doing this for this set of new providers because otherwise existing configuration files will break even more badly.)
This commit is contained in:
Ashwin Bharambe 2024-11-11 09:29:18 -08:00 committed by GitHub
parent 6d38b1690b
commit c1f7ba3aed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 464 additions and 500 deletions

View file

@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.code_scanner.value:
if shield.shield_type != ShieldType.code_scanner:
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
async def run_shield(

View file

@ -7,5 +7,5 @@
from pydantic import BaseModel
class CodeShieldConfig(BaseModel):
class CodeScannerConfig(BaseModel):
pass

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(
config, LlamaGuardConfig
), f"Unexpected config type: {type(config)}"
impl = LlamaGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -4,20 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
from typing import List
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, field_validator
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class LlamaGuardShieldConfig(BaseModel):
class LlamaGuardConfig(BaseModel):
model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
@ -41,8 +35,3 @@ class LlamaGuardShieldConfig(BaseModel):
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
enable_prompt_guard: Optional[bool] = False

View file

@ -7,16 +7,21 @@
import re
from string import Template
from typing import List, Optional
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import LlamaGuardConfig
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
@ -107,16 +112,52 @@ PROMPT_TEMPLATE = Template(
)
class LlamaGuardShield(ShieldBase):
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
self.shield = LlamaGuardShield(
model=self.config.model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
return await self.shield.run(messages)
class LlamaGuardShield:
def __init__(
self,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
excluded_categories: Optional[List[str]] = None,
):
super().__init__(on_violation_action)
if excluded_categories is None:
excluded_categories = []
@ -174,7 +215,7 @@ class LlamaGuardShield(ShieldBase):
)
return messages
async def run(self, messages: List[Message]) -> ShieldResponse:
async def run(self, messages: List[Message]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@ -195,8 +236,7 @@ class LlamaGuardShield(ShieldBase):
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response
return self.get_shield_response(content)
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
@ -250,19 +290,23 @@ class LlamaGuardShield(ShieldBase):
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
def get_shield_response(self, response: str) -> RunShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return ShieldResponse(is_violation=False)
return RunShieldResponse(violation=None)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
return RunShieldResponse(violation=None)
return RunShieldResponse(
violation=SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code},
),
)
raise ValueError(f"Unexpected response: {response}")

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import LlamaGuardShieldConfig, SafetyConfig # noqa: F401
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -1,57 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
return interleaved_text_media_as_str(message.content)
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()

View file

@ -1,145 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)

View file

@ -1,107 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .base import OnViolationAction, ShieldBase
from .config import SafetyConfig
from .llama_guard import LlamaGuardShield
from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard]
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard)
async def initialize(self) -> None:
if self.config.enable_prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
_ = PromptGuardShield.instance(model_dir)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type not in self.available_shields:
raise ValueError(f"Shield type {shield.shield_type} not supported")
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
shield_impl = self.get_shield_impl(shield)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield_impl.run(messages)
violation = None
if (
res.is_violation
and shield_impl.on_violation_action != OnViolationAction.IGNORE
):
violation = SafetyViolation(
violation_level=(
ViolationLevel.ERROR
if shield_impl.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
return RunShieldResponse(violation=violation)
def get_shield_impl(self, shield: Shield) -> ShieldBase:
if shield.shield_type == ShieldType.llama_guard:
cfg = self.config.llama_guard_shield
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
)
elif shield.shield_type == ShieldType.prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection":
return InjectionShield.instance(model_dir)
elif subtype == "jailbreak":
return JailbreakShield.instance(model_dir)
else:
raise ValueError(f"Unknown prompt guard type: {subtype}")
else:
raise ValueError(f"Unknown shield type: {shield.shield_type}")

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from pydantic import BaseModel, field_validator
class PromptGuardType(Enum):
injection = "injection"
jailbreak = "jailbreak"
class PromptGuardConfig(BaseModel):
guard_type: str = PromptGuardType.injection.value
@classmethod
@field_validator("guard_type")
def validate_guard_type(cls, v):
if v not in [t.value for t in PromptGuardType]:
raise ValueError(f"Unknown prompt guard type: {v}")
return v

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import torch
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import PromptGuardConfig, PromptGuardType
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: PromptGuardConfig, _deps) -> None:
self.config = config
async def initialize(self) -> None:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
self.shield = PromptGuardShield(model_dir, self.config)
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.prompt_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Unknown shield {shield_id}")
return await self.shield.run(messages)
class PromptGuardShield:
def __init__(
self,
model_dir: str,
config: PromptGuardConfig,
threshold: float = 0.9,
temperature: float = 1.0,
):
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.config = config
self.temperature = temperature
self.threshold = threshold
self.device = "cuda"
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_text_media_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
violation = None
if self.config.guard_type == PromptGuardType.injection.value and (
score_embedded + score_malicious > self.threshold
):
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Sorry, I cannot do this.",
metadata={
"violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
},
)
elif (
self.config.guard_type == PromptGuardType.jailbreak.value
and score_malicious > self.threshold
):
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return RunShieldResponse(violation=violation)