From d442af0818457d9f7509155bb57b185a03d8d8d8 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 25 Sep 2024 11:06:59 -0700
Subject: [PATCH] Add safety impl for llama guard vision

---
 .../impls/meta_reference/safety/__init__.py   |   4 +-
 .../impls/meta_reference/safety/safety.py     |  28 +--
 .../safety/shields/llama_guard.py             | 226 +++++++++++++-----
 3 files changed, 182 insertions(+), 76 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/safety/__init__.py b/llama_stack/providers/impls/meta_reference/safety/__init__.py
index ad175ce46..6c686120c 100644
--- a/llama_stack/providers/impls/meta_reference/safety/__init__.py
+++ b/llama_stack/providers/impls/meta_reference/safety/__init__.py
@@ -7,11 +7,11 @@
 from .config import SafetyConfig
 
 
-async def get_provider_impl(config: SafetyConfig, _deps):
+async def get_provider_impl(config: SafetyConfig, deps):
     from .safety import MetaReferenceSafetyImpl
 
     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
 
-    impl = MetaReferenceSafetyImpl(config)
+    impl = MetaReferenceSafetyImpl(config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 6cf8a79d2..3c0426a9e 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -7,8 +7,10 @@
 from llama_models.sku_list import resolve_model
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api
 
 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig) -> None:
+    def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
+        self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
-        shield_cfg = self.config.llama_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
-            _ = LlamaGuardShield.instance(
-                model_dir=model_dir,
-                excluded_categories=shield_cfg.excluded_categories,
-                disable_input_check=shield_cfg.disable_input_check,
-                disable_output_check=shield_cfg.disable_output_check,
-            )
-
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
@@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
+            cfg = cfg.llama_guard_shield
             assert (
-                cfg.llama_guard_shield is not None
+                cfg is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
-            return LlamaGuardShield.instance(model_dir=model_dir)
+            model_dir = resolve_and_get_path(cfg.model)
+
+            return LlamaGuardShield(
+                model_dir=model_dir,
+                excluded_categories=cfg.excluded_categories,
+                disable_input_check=cfg.disable_input_check,
+                disable_output_check=cfg.disable_output_check,
+            )
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index c29361b95..e8c7b3560 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -10,8 +10,12 @@
 from string import Template
 from typing import List, Optional
 
 import torch
-from llama_models.llama3.api.datatypes import Message, Role
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    MllamaForConditionalGeneration,
+    MllamaProcessor,
+)
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -67,10 +71,22 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_SELF_HARM,
     CAT_SEXUAL_CONTENT,
     CAT_ELECTIONS,
-    CAT_CODE_INTERPRETER_ABUSE,
 ]
 
-PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+# model names
+LG_3_8B = "Llama-Guard-3-8B"
+LG_3_1B = "Llama-Guard-3-1B"
+LG_3_11B_VISION = "Llama-Guard-3-11B-Vision"
+
+
+MODEL_TO_SAFETY_CATEGORIES_MAP = {
+    LG_3_8B: DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
+    LG_3_1B: DEFAULT_LG_V3_SAFETY_CATEGORIES,
+    LG_3_11B_VISION: DEFAULT_LG_V3_SAFETY_CATEGORIES,
+}
+
+
+PROMPT_TASK = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
 
 
 SAFETY_CATEGORIES = """
@@ -91,7 +107,7 @@ $conversations
 PROMPT_INSTRUCTIONS = """
 Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
  - First line must read 'safe' or 'unsafe'.
- - If unsafe, a second line must include a comma-separated list of violated categories."""
+ - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
 
 
 PROMPT_TEMPLATE = Template(
@@ -100,38 +116,21 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-    @staticmethod
-    def instance(
-        on_violation_action=OnViolationAction.RAISE,
-        model_dir: str = None,
-        excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
-    ) -> "LlamaGuardShield":
-        global _INSTANCE
-        if _INSTANCE is None:
-            _INSTANCE = LlamaGuardShield(
-                on_violation_action,
-                model_dir,
-                excluded_categories,
-                disable_input_check,
-                disable_output_check,
-            )
-        return _INSTANCE
-
     def __init__(
         self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-        model_dir: str = None,
+        model_dir: str,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
         super().__init__(on_violation_action)
 
         dtype = torch.bfloat16
+        self.model_dir = model_dir
+        self.device = "cuda"
 
-        assert model_dir is not None, "Llama Guard model_dir is None"
+        assert self.model_dir is not None, "Llama Guard model_dir is None"
 
         if excluded_categories is None:
             excluded_categories = []
@@ -140,17 +139,24 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
 
-        self.device = "cuda"
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
-        # load model
         torch_dtype = torch.bfloat16
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map=self.device
-        )
+
+        if self.is_lg_vision():
+
+            self.model = MllamaForConditionalGeneration.from_pretrained(
+                self.model_dir, device_map=self.device, torch_dtype=torch_dtype
+            )
+            self.processor = MllamaProcessor.from_pretrained(self.model_dir)
+        else:
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_dir, torch_dtype=torch_dtype, device_map=self.device
+            )
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
@@ -166,14 +172,15 @@ class LlamaGuardShield(ShieldBase):
         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
             excluded_categories = []
 
-        categories = []
-        for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
+        final_categories = []
+        all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.get_model_name()]
+        for cat in all_categories:
             cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
             if cat_code in excluded_categories:
                 continue
-            categories.append(f"{cat_code}: {cat}.")
+            final_categories.append(f"{cat_code}: {cat}.")
 
-        return categories
+        return final_categories
 
     def build_prompt(self, messages: List[Message]) -> str:
         categories = self.get_safety_categories()
@@ -188,6 +195,7 @@ class LlamaGuardShield(ShieldBase):
         )
 
     def get_shield_response(self, response: str) -> ShieldResponse:
+        response = response.strip()
         if response == SAFE_RESPONSE:
             return ShieldResponse(is_violation=False)
         unsafe_code = self.check_unsafe_response(response)
@@ -203,7 +211,119 @@ class LlamaGuardShield(ShieldBase):
             raise ValueError(f"Unexpected response: {response}")
 
+    def build_mm_prompt(self, messages: List[Message]) -> str:
+        conversation = []
+        most_recent_img = None
+
+        for m in messages[::-1]:
+            if isinstance(m.content, str):
+                conversation.append(
+                    {
+                        "role": m.role,
+                        "content": [{"type": "text", "text": m.content}],
+                    }
+                )
+            elif isinstance(m.content, ImageMedia):
+                if most_recent_img is None and m.role == Role.user.value:
+                    most_recent_img = m.content
+                conversation.append(
+                    {
+                        "role": m.role,
+                        "content": [{"type": "image"}],
+                    }
+                )
+
+            elif isinstance(m.content, list):
+                content = []
+                for c in m.content:
+                    if isinstance(c, str):
+                        content.append({"type": "text", "text": c})
+                    elif isinstance(c, ImageMedia):
+                        if most_recent_img is None and m.role == Role.user.value:
+                            most_recent_img = c
+                        content.append({"type": "image"})
+                    else:
+                        raise ValueError(f"Unknown content type: {c}")
+
+                conversation.append(
+                    {
+                        "role": m.role,
+                        "content": content,
+                    }
+                )
+            else:
+                raise ValueError(f"Unknown content type: {m.content}")
+
+        return conversation[::-1], most_recent_img
+
+    async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
+        formatted_messages, most_recent_img = self.build_mm_prompt(messages)
+        raw_image = None
+        if most_recent_img:
+            raw_image = interleaved_text_media_localize(most_recent_img)
+            raw_image = raw_image.image
+        llama_guard_input_templ_applied = self.processor.apply_chat_template(
+            formatted_messages,
+            add_generation_prompt=True,
+            tokenize=False,
+            skip_special_tokens=False,
+        )
+        inputs = self.processor(
+            text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
+        ).to(self.device)
+        output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
+        response = self.processor.decode(
+            output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+        )
+        shield_response = self.get_shield_response(response)
+        return shield_response
+
+    async def run_lg_text(self, messages: List[Message]):
+        prompt = self.build_prompt(messages)
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+        prompt_len = input_ids.shape[1]
+        output = self.model.generate(
+            input_ids=input_ids,
+            max_new_tokens=20,
+            output_scores=True,
+            return_dict_in_generate=True,
+            pad_token_id=0,
+        )
+        generated_tokens = output.sequences[:, prompt_len:]
+
+        response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+
+        shield_response = self.get_shield_response(response)
+        return shield_response
+
+    def get_model_name(self):
+        return self.model_dir.split("/")[-1]
+
+    def is_lg_vision(self):
+        model_name = self.get_model_name()
+        return model_name == LG_3_11B_VISION
+
+    def validate_messages(self, messages: List[Message]) -> None:
+        if len(messages) == 0:
+            raise ValueError("Messages must not be empty")
+        if messages[0].role != Role.user.value:
+            raise ValueError("Messages must start with user")
+
+        if len(messages) >= 2 and (
+            messages[0].role == Role.user.value and messages[1].role == Role.user.value
+        ):
+            messages = messages[1:]
+
+        for i in range(1, len(messages)):
+            if messages[i].role == messages[i - 1].role:
+                raise ValueError(
+                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
+                )
+        return messages
+
     async def run(self, messages: List[Message]) -> ShieldResponse:
+
+        messages = self.validate_messages(messages)
         if self.disable_input_check and messages[-1].role == Role.user.value:
             return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
@@ -211,27 +331,13 @@
                 is_violation=False,
             )
         else:
-            prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=20,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0,
-            )
-            generated_tokens = output.sequences[:, prompt_len:]
-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
-            return shield_response
+            if self.is_lg_vision():
+
+                shield_response = await self.run_lg_mm(messages)
+
+            else:
+
+                shield_response = await self.run_lg_text(messages)
+
+            return shield_response
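
Usage sketch (illustrative, not part of the commit): with the LlamaGuardShield.instance() singleton removed, callers construct the shield directly, and run() dispatches to the text or vision path based on the checkpoint directory name. The snippet below assumes a hypothetical local Llama-Guard-3-11B-Vision checkpoint path and the UserMessage type that the wildcard datatypes imports provide.

    import asyncio

    # Hypothetical checkpoint location; the trailing directory name is what
    # is_lg_vision() compares against LG_3_11B_VISION to pick the Mllama path.
    shield = LlamaGuardShield(
        model_dir="/checkpoints/Llama-Guard-3-11B-Vision",
        excluded_categories=[],  # keep every category from MODEL_TO_SAFETY_CATEGORIES_MAP
        disable_input_check=False,
        disable_output_check=False,
    )

    async def main() -> None:
        # run() validates role alternation via validate_messages(), then routes
        # to run_lg_mm() for the vision checkpoint (run_lg_text() otherwise).
        resp = await shield.run([UserMessage(content="Is this image okay to post?")])
        print(resp.is_violation)

    asyncio.run(main())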