Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)
Split safety into (llama-guard, prompt-guard, code-scanner)
Commit fdaec91747 (parent 6d38b1690b)
14 changed files with 295 additions and 368 deletions
@@ -7,5 +7,5 @@
 from pydantic import BaseModel


-class CodeShieldConfig(BaseModel):
+class CodeScannerConfig(BaseModel):
     pass
llama_stack/providers/inline/safety/llama_guard/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import LlamaGuardConfig
+
+
+async def get_provider_impl(config: LlamaGuardConfig, deps):
+    from .llama_guard import LlamaGuardSafetyImpl
+
+    assert isinstance(
+        config, LlamaGuardConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = LlamaGuardSafetyImpl(config, deps)
+    await impl.initialize()
+    return impl
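
For context, the sketch below (not part of the commit) shows how this entry point would be exercised; the stub inference object and the driver script are assumptions, and only get_provider_impl, LlamaGuardConfig and Api.inference come from the diff itself.

# Hypothetical driver -- not from the commit -- exercising get_provider_impl()
# with a stand-in inference dependency.
import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.safety.llama_guard import (
    LlamaGuardConfig,
    get_provider_impl,
)


class StubInference:
    """Placeholder for the real inference provider (assumption)."""


async def main() -> None:
    config = LlamaGuardConfig()  # defaults to model="Llama-Guard-3-1B"
    impl = await get_provider_impl(config, {Api.inference: StubInference()})
    print(type(impl).__name__)  # LlamaGuardSafetyImpl


if __name__ == "__main__":
    asyncio.run(main())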

@@ -4,20 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
-from typing import List, Optional
+from typing import List

 from llama_models.sku_list import CoreModelId, safety_models

 from pydantic import BaseModel, field_validator


-class PromptGuardType(Enum):
-    injection = "injection"
-    jailbreak = "jailbreak"
-
-
-class LlamaGuardShieldConfig(BaseModel):
+class LlamaGuardConfig(BaseModel):
     model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []

@@ -41,8 +35,3 @@ class LlamaGuardShieldConfig(BaseModel):
                 f"Invalid model: {model}. Must be one of {permitted_models}"
             )
         return model
-
-
-class SafetyConfig(BaseModel):
-    llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
-    enable_prompt_guard: Optional[bool] = False
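
A small usage sketch of the renamed config, assuming the module path llama_stack.providers.inline.safety.llama_guard.config and that the field validator (only partially shown in this hunk) still restricts model to the Llama Guard safety models:

# Assumed import path; only the field names and the default come from the diff.
from llama_stack.providers.inline.safety.llama_guard.config import LlamaGuardConfig

cfg = LlamaGuardConfig(excluded_categories=["Violent Crimes"])
print(cfg.model)  # "Llama-Guard-3-1B" (default)

try:
    LlamaGuardConfig(model="Llama-3.1-8B-Instruct")  # not a Llama Guard model
except Exception as exc:  # pydantic wraps the validator's ValueError
    print(exc)  # Invalid model: ... Must be one of [...]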

@@ -7,16 +7,67 @@
 import re

 from string import Template
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api

-from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
+from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+
+from .config import LlamaGuardConfig
+
+
+class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+        self.config = config
+        self.inference_api = deps[Api.inference]
+
+    async def initialize(self) -> None:
+        self.shield = LlamaGuardShield(
+            model=self.config.model,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.llama_guard.value,
+                shield_type=ShieldType.llama_guard.value,
+                params={},
+            ),
+        ]
+
+    async def run_shield(
+        self,
+        identifier: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(identifier)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {identifier}")
+
+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)
+
+        return await self.shield.run(messages)
+
+
+CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

 SAFE_RESPONSE = "safe"
-_INSTANCE = None

 CAT_VIOLENT_CRIMES = "Violent Crimes"
 CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"

@@ -107,16 +158,13 @@ PROMPT_TEMPLATE = Template(
 )


-class LlamaGuardShield(ShieldBase):
+class LlamaGuardShield:
     def __init__(
         self,
         model: str,
         inference_api: Inference,
-        excluded_categories: List[str] = None,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+        excluded_categories: Optional[List[str]] = None,
     ):
-        super().__init__(on_violation_action)
-
         if excluded_categories is None:
             excluded_categories = []

@@ -174,7 +222,7 @@ class LlamaGuardShield(ShieldBase):
         )
         return messages

-    async def run(self, messages: List[Message]) -> ShieldResponse:
+    async def run(self, messages: List[Message]) -> RunShieldResponse:
         messages = self.validate_messages(messages)

         if self.model == CoreModelId.llama_guard_3_11b_vision.value:

@@ -195,8 +243,7 @@ class LlamaGuardShield(ShieldBase):
                 content += event.delta

         content = content.strip()
-        shield_response = self.get_shield_response(content)
-        return shield_response
+        return self.get_shield_response(content)

     def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
         return UserMessage(content=self.build_prompt(messages))

@@ -250,19 +297,23 @@ class LlamaGuardShield(ShieldBase):
             conversations=conversations_str,
         )

-    def get_shield_response(self, response: str) -> ShieldResponse:
+    def get_shield_response(self, response: str) -> RunShieldResponse:
         response = response.strip()
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return RunShieldResponse(violation=None)

         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=unsafe_code,
-                violation_return_message=CANNED_RESPONSE_TEXT,
+                return RunShieldResponse(violation=None)
+            return RunShieldResponse(
+                violation=SafetyViolation(
+                    violation_level=ViolationLevel.ERROR,
+                    user_message=CANNED_RESPONSE_TEXT,
+                    metadata={"violation_type": unsafe_code},
+                ),
             )

         raise ValueError(f"Unexpected response: {response}")
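
The heart of this file's change is that the shield now returns the API-level RunShieldResponse/SafetyViolation types directly instead of the internal ShieldResponse. The standalone sketch below mirrors that mapping with stand-in dataclasses; it does not import llama_stack, and it assumes the usual "safe" / "unsafe\nS1,S2" Llama Guard output format.

# Standalone sketch of the new response mapping (stand-in types, not the
# llama_stack classes): "safe" -> no violation, "unsafe\nS1,S2" -> violation
# unless every code is in excluded_categories. Mirrors get_shield_response().
from dataclasses import dataclass, field
from typing import Dict, List, Optional

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
SAFE_RESPONSE = "safe"


@dataclass
class SafetyViolation:  # stand-in for llama_stack.apis.safety.SafetyViolation
    violation_level: str
    user_message: str
    metadata: Dict[str, str] = field(default_factory=dict)


@dataclass
class RunShieldResponse:  # stand-in for the API response type
    violation: Optional[SafetyViolation] = None


def get_shield_response(response: str, excluded_categories: List[str]) -> RunShieldResponse:
    response = response.strip()
    if response == SAFE_RESPONSE:
        return RunShieldResponse(violation=None)

    parts = response.split("\n")
    if len(parts) == 2 and parts[0] == "unsafe":
        unsafe_code = parts[1]
        if set(unsafe_code.split(",")).issubset(set(excluded_categories)):
            return RunShieldResponse(violation=None)
        return RunShieldResponse(
            violation=SafetyViolation(
                violation_level="error",
                user_message=CANNED_RESPONSE_TEXT,
                metadata={"violation_type": unsafe_code},
            )
        )
    raise ValueError(f"Unexpected response: {response}")


print(get_shield_response("safe", []).violation)            # None
print(get_shield_response("unsafe\nS1", []).violation)      # violation carrying S1
print(get_shield_response("unsafe\nS1", ["S1"]).violation)  # None (excluded)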

@@ -1,17 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import LlamaGuardShieldConfig, SafetyConfig  # noqa: F401
-
-
-async def get_provider_impl(config: SafetyConfig, deps):
-    from .safety import MetaReferenceSafetyImpl
-
-    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceSafetyImpl(config, deps)
-    await impl.initialize()
-    return impl

@@ -1,57 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from abc import ABC, abstractmethod
-from typing import List
-
-from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from pydantic import BaseModel
-from llama_stack.apis.safety import *  # noqa: F403
-
-CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
-
-
-# TODO: clean this up; just remove this type completely
-class ShieldResponse(BaseModel):
-    is_violation: bool
-    violation_type: Optional[str] = None
-    violation_return_message: Optional[str] = None
-
-
-# TODO: this is a caller / agent concern
-class OnViolationAction(Enum):
-    IGNORE = 0
-    WARN = 1
-    RAISE = 2
-
-
-class ShieldBase(ABC):
-    def __init__(
-        self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        self.on_violation_action = on_violation_action
-
-    @abstractmethod
-    async def run(self, messages: List[Message]) -> ShieldResponse:
-        raise NotImplementedError()
-
-
-def message_content_as_str(message: Message) -> str:
-    return interleaved_text_media_as_str(message.content)
-
-
-class TextShield(ShieldBase):
-    def convert_messages_to_text(self, messages: List[Message]) -> str:
-        return "\n".join([message_content_as_str(m) for m in messages])
-
-    async def run(self, messages: List[Message]) -> ShieldResponse:
-        text = self.convert_messages_to_text(messages)
-        return await self.run_impl(text)
-
-    @abstractmethod
-    async def run_impl(self, text: str) -> ShieldResponse:
-        raise NotImplementedError()

@@ -1,145 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import auto, Enum
-from typing import List
-
-import torch
-
-from llama_models.llama3.api.datatypes import Message
-from termcolor import cprint
-
-from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-
-
-class PromptGuardShield(TextShield):
-    class Mode(Enum):
-        INJECTION = auto()
-        JAILBREAK = auto()
-
-    _instances = {}
-    _model_cache = None
-
-    @staticmethod
-    def instance(
-        model_dir: str,
-        threshold: float = 0.9,
-        temperature: float = 1.0,
-        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
-        on_violation_action=OnViolationAction.RAISE,
-    ) -> "PromptGuardShield":
-        action_value = on_violation_action.value
-        key = (model_dir, threshold, temperature, mode, action_value)
-        if key not in PromptGuardShield._instances:
-            PromptGuardShield._instances[key] = PromptGuardShield(
-                model_dir=model_dir,
-                threshold=threshold,
-                temperature=temperature,
-                mode=mode,
-                on_violation_action=on_violation_action,
-            )
-        return PromptGuardShield._instances[key]
-
-    def __init__(
-        self,
-        model_dir: str,
-        threshold: float = 0.9,
-        temperature: float = 1.0,
-        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        super().__init__(on_violation_action)
-        assert (
-            model_dir is not None
-        ), "Must provide a model directory for prompt injection shield"
-        if temperature <= 0:
-            raise ValueError("Temperature must be greater than 0")
-        self.device = "cuda"
-        if PromptGuardShield._model_cache is None:
-            from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-            # load model and tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(model_dir)
-            model = AutoModelForSequenceClassification.from_pretrained(
-                model_dir, device_map=self.device
-            )
-            PromptGuardShield._model_cache = (tokenizer, model)
-
-        self.tokenizer, self.model = PromptGuardShield._model_cache
-        self.temperature = temperature
-        self.threshold = threshold
-        self.mode = mode
-
-    def convert_messages_to_text(self, messages: List[Message]) -> str:
-        return message_content_as_str(messages[-1])
-
-    async def run_impl(self, text: str) -> ShieldResponse:
-        # run model on messages and return response
-        inputs = self.tokenizer(text, return_tensors="pt")
-        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        logits = outputs[0]
-        probabilities = torch.softmax(logits / self.temperature, dim=-1)
-        score_embedded = probabilities[0, 1].item()
-        score_malicious = probabilities[0, 2].item()
-        cprint(
-            f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
-            color="magenta",
-        )
-
-        if self.mode == self.Mode.INJECTION and (
-            score_embedded + score_malicious > self.threshold
-        ):
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
-            )
-        elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=f"prompt_injection:malicious={score_malicious}",
-                violation_return_message="Sorry, I cannot do this.",
-            )
-
-        return ShieldResponse(
-            is_violation=False,
-        )
-
-
-class JailbreakShield(PromptGuardShield):
-    def __init__(
-        self,
-        model_dir: str,
-        threshold: float = 0.9,
-        temperature: float = 1.0,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        super().__init__(
-            model_dir=model_dir,
-            threshold=threshold,
-            temperature=temperature,
-            mode=PromptGuardShield.Mode.JAILBREAK,
-            on_violation_action=on_violation_action,
-        )
-
-
-class InjectionShield(PromptGuardShield):
-    def __init__(
-        self,
-        model_dir: str,
-        threshold: float = 0.9,
-        temperature: float = 1.0,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-    ):
-        super().__init__(
-            model_dir=model_dir,
-            threshold=threshold,
-            temperature=temperature,
-            mode=PromptGuardShield.Mode.INJECTION,
-            on_violation_action=on_violation_action,
-        )

@@ -1,107 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, List
-
-from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import Api
-
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-
-from .base import OnViolationAction, ShieldBase
-from .config import SafetyConfig
-from .llama_guard import LlamaGuardShield
-from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
-
-
-PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
-SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard]
-
-
-class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: SafetyConfig, deps) -> None:
-        self.config = config
-        self.inference_api = deps[Api.inference]
-
-        self.available_shields = []
-        if config.llama_guard_shield:
-            self.available_shields.append(ShieldType.llama_guard)
-        if config.enable_prompt_guard:
-            self.available_shields.append(ShieldType.prompt_guard)
-
-    async def initialize(self) -> None:
-        if self.config.enable_prompt_guard:
-            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
-            _ = PromptGuardShield.instance(model_dir)
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def register_shield(self, shield: Shield) -> None:
-        if shield.shield_type not in self.available_shields:
-            raise ValueError(f"Shield type {shield.shield_type} not supported")
-
-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
-        if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
-
-        shield_impl = self.get_shield_impl(shield)
-
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        # TODO: we can refactor ShieldBase, etc. to be inline with the API types
-        res = await shield_impl.run(messages)
-        violation = None
-        if (
-            res.is_violation
-            and shield_impl.on_violation_action != OnViolationAction.IGNORE
-        ):
-            violation = SafetyViolation(
-                violation_level=(
-                    ViolationLevel.ERROR
-                    if shield_impl.on_violation_action == OnViolationAction.RAISE
-                    else ViolationLevel.WARN
-                ),
-                user_message=res.violation_return_message,
-                metadata={
-                    "violation_type": res.violation_type,
-                },
-            )
-
-        return RunShieldResponse(violation=violation)
-
-    def get_shield_impl(self, shield: Shield) -> ShieldBase:
-        if shield.shield_type == ShieldType.llama_guard:
-            cfg = self.config.llama_guard_shield
-            return LlamaGuardShield(
-                model=cfg.model,
-                inference_api=self.inference_api,
-                excluded_categories=cfg.excluded_categories,
-            )
-        elif shield.shield_type == ShieldType.prompt_guard:
-            model_dir = model_local_dir(PROMPT_GUARD_MODEL)
-            subtype = shield.params.get("prompt_guard_type", "injection")
-            if subtype == "injection":
-                return InjectionShield.instance(model_dir)
-            elif subtype == "jailbreak":
-                return JailbreakShield.instance(model_dir)
-            else:
-                raise ValueError(f"Unknown prompt guard type: {subtype}")
-        else:
-            raise ValueError(f"Unknown shield type: {shield.shield_type}")

llama_stack/providers/inline/safety/prompt_guard/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import PromptGuardConfig  # noqa: F401
+
+
+async def get_provider_impl(config: PromptGuardConfig, deps):
+    from .prompt_guard import PromptGuardSafetyImpl
+
+    impl = PromptGuardSafetyImpl(config, deps)
+    await impl.initialize()
+    return impl

llama_stack/providers/inline/safety/prompt_guard/config.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+
+from pydantic import BaseModel, field_validator
+
+
+class PromptGuardType(Enum):
+    injection = "injection"
+    jailbreak = "jailbreak"
+
+
+class PromptGuardConfig(BaseModel):
+    guard_type: str = PromptGuardType.injection.value
+
+    @classmethod
+    @field_validator("guard_type")
+    def validate_guard_type(cls, v):
+        if v not in [t.value for t in PromptGuardType]:
+            raise ValueError(f"Unknown prompt guard type: {v}")
+        return v
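
A brief usage sketch, assuming the module path llama_stack.providers.inline.safety.prompt_guard.config; the guard_type values come straight from the enum above, and the validator is intended to reject anything else.

# Assumed import path; field names and defaults come from the diff.
from llama_stack.providers.inline.safety.prompt_guard.config import (
    PromptGuardConfig,
    PromptGuardType,
)

injection_cfg = PromptGuardConfig()  # guard_type defaults to "injection"
jailbreak_cfg = PromptGuardConfig(guard_type=PromptGuardType.jailbreak.value)
print(injection_cfg.guard_type, jailbreak_cfg.guard_type)
# validate_guard_type is meant to raise ValueError for any other guard_type value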

llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+import torch
+
+from llama_stack.distribution.utils.model_utils import model_local_dir
+from termcolor import cprint
+
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+
+from .config import PromptGuardConfig, PromptGuardType
+
+
+PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
+
+
+class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    def __init__(self, config: PromptGuardConfig, _deps) -> None:
+        self.config = config
+
+    async def initialize(self) -> None:
+        model_dir = model_local_dir(PROMPT_GUARD_MODEL)
+        self.shield = PromptGuardShield(model_dir, self.config)
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.prompt_guard.value,
+                shield_type=ShieldType.prompt_guard.value,
+                params={},
+            )
+        ]
+
+    async def run_shield(
+        self,
+        identifier: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(identifier)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {identifier}")
+
+        return await self.shield.run(messages)
+
+
+class PromptGuardShield:
+    def __init__(
+        self,
+        model_dir: str,
+        config: PromptGuardConfig,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+    ):
+        assert (
+            model_dir is not None
+        ), "Must provide a model directory for prompt injection shield"
+        if temperature <= 0:
+            raise ValueError("Temperature must be greater than 0")
+
+        self.config = config
+        self.temperature = temperature
+        self.threshold = threshold
+
+        self.device = "cuda"
+
+        # load model and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            model_dir, device_map=self.device
+        )
+
+    async def run(self, messages: List[Message]) -> RunShieldResponse:
+        message = messages[-1]
+        text = interleaved_text_media_as_str(message.content)
+
+        # run model on messages and return response
+        inputs = self.tokenizer(text, return_tensors="pt")
+        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        logits = outputs[0]
+        probabilities = torch.softmax(logits / self.temperature, dim=-1)
+        score_embedded = probabilities[0, 1].item()
+        score_malicious = probabilities[0, 2].item()
+        cprint(
+            f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
+            color="magenta",
+        )
+
+        violation = None
+        if self.config.guard_type == PromptGuardType.injection.value and (
+            score_embedded + score_malicious > self.threshold
+        ):
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                user_message="Sorry, I cannot do this.",
+                metadata={
+                    "violation_type": f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
+                },
+            )
+        elif (
+            self.config.guard_type == PromptGuardType.jailbreak.value
+            and score_malicious > self.threshold
+        ):
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                violation_type=f"prompt_injection:malicious={score_malicious}",
+                violation_return_message="Sorry, I cannot do this.",
+            )
+
+        return RunShieldResponse(violation=violation)
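
To make the decision rule explicit, here is a self-contained sketch (stand-in logic, no llama_stack imports) of how the shield turns the classifier's three-way softmax into a violation: "injection" thresholds the combined embedded + malicious mass, while "jailbreak" thresholds the malicious score alone.

# Standalone illustration of the thresholding above; names and the toy
# probabilities are examples, not values from the commit.
from typing import List, Optional


def guard_decision(
    probabilities: List[float],  # [benign, embedded, malicious]
    guard_type: str = "injection",
    threshold: float = 0.9,
) -> Optional[str]:
    score_embedded = probabilities[1]
    score_malicious = probabilities[2]

    if guard_type == "injection" and score_embedded + score_malicious > threshold:
        return f"prompt_injection:embedded={score_embedded},malicious={score_malicious}"
    if guard_type == "jailbreak" and score_malicious > threshold:
        return f"prompt_injection:malicious={score_malicious}"
    return None  # no violation


print(guard_decision([0.05, 0.60, 0.35]))               # injection flagged (0.95 > 0.9)
print(guard_decision([0.05, 0.60, 0.35], "jailbreak"))  # None (0.35 <= 0.9)
print(guard_decision([0.02, 0.03, 0.95], "jailbreak"))  # jailbreak flagged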

@@ -29,6 +29,42 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.inference,
             ],
+            deprecation_warning="Please use the `llama-guard` / `prompt-guard` / `code-scanner` providers instead.",
+        ),
+        InlineProviderSpec(
+            api=Api.safety,
+            provider_type="llama-guard",
+            pip_packages=[],
+            module="llama_stack.providers.inline.safety.llama_guard",
+            config_class="llama_stack.providers.inline.safety.llama_guard.LlamaGuardConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
+        ),
+        InlineProviderSpec(
+            api=Api.safety,
+            provider_type="prompt-guard",
+            pip_packages=[
+                "transformers",
+                "torch --index-url https://download.pytorch.org/whl/cpu",
+            ],
+            module="llama_stack.providers.inline.safety.prompt_guard",
+            config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
+        ),
+        InlineProviderSpec(
+            api=Api.safety,
+            provider_type="code-scanner",
+            pip_packages=[
+                "codeshield",
+            ],
+            module="llama_stack.providers.inline.safety.code_scanner",
+            config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
         ),
         remote_provider_spec(
             api=Api.safety,

@@ -48,14 +84,4 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
             ),
         ),
-        InlineProviderSpec(
-            api=Api.safety,
-            provider_type="meta-reference/codeshield",
-            pip_packages=[
-                "codeshield",
-            ],
-            module="llama_stack.providers.inline.safety.meta_reference",
-            config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
-            api_dependencies=[],
-        ),
     ]
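
A quick way to confirm the registry change, assuming the registry module lives at llama_stack.providers.registry.safety (the path is not shown in this view); available_providers() and the provider_type field come from the diff.

# Assumed module path; prints the registered safety provider types.
from llama_stack.providers.registry.safety import available_providers

for spec in available_providers():
    print(spec.provider_type)
# expected to include llama-guard, prompt-guard and code-scanner alongside the
# deprecated meta-reference provider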