From 37ca22cda6e86a162e5808c8574b77c5c7e7c560 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 25 Sep 2024 19:40:49 -0700
Subject: [PATCH] Use inference APIs for executing Llama Guard

---
 .../impls/meta_reference/safety/safety.py     |   4 +-
 .../safety/shields/llama_guard.py             | 251 +++++++-----------
 llama_stack/providers/registry/safety.py      |   3 +-
 3 files changed, 94 insertions(+), 164 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 3c0426a9e..6bb851596 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
             assert (
                 cfg is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.model)
 
             return LlamaGuardShield(
-                model_dir=model_dir,
+                model=cfg.model,
+                inference_api=self.inference_api,
                 excluded_categories=cfg.excluded_categories,
                 disable_input_check=cfg.disable_input_check,
                 disable_output_check=cfg.disable_output_check,
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index 5ee562179..00800c3d9 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -9,14 +9,8 @@ import re
 from string import Template
 from typing import List, Optional
 
-import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    MllamaForConditionalGeneration,
-    MllamaProcessor
-)
-
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -120,7 +114,8 @@ PROMPT_TEMPLATE = Template(
 class LlamaGuardShield(ShieldBase):
     def __init__(
         self,
-        model_dir: str,
+        model: str,
+        inference_api: Inference,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
@@ -128,12 +123,6 @@ class LlamaGuardShield(ShieldBase):
     ):
         super().__init__(on_violation_action)
 
-        dtype = torch.bfloat16
-        self.model_dir = model_dir
-        self.device = "cuda"
-
-        assert self.model_dir is not None, "Llama Guard model_dir is None"
-
         if excluded_categories is None:
             excluded_categories = []
 
@@ -141,27 +130,12 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
 
+        self.model = model
+        self.inference_api = inference_api
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
-        torch_dtype = torch.bfloat16
-
-        self.model_dir = f"meta-llama/{self.get_model_name()}"
-
-        if self.is_lg_vision():
-
-            self.model = MllamaForConditionalGeneration.from_pretrained(
-                self.model_dir, device_map=self.device, torch_dtype=torch_dtype
-            )
-            self.processor = MllamaProcessor.from_pretrained(self.model_dir)
-        else:
-
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_dir, torch_dtype=torch_dtype, device_map=self.device
-            )
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -186,127 +160,6 @@ class LlamaGuardShield(ShieldBase):
 
         return final_categories
 
-    def build_prompt(self, messages: List[Message]) -> str:
-        categories = self.get_safety_categories()
-        categories_str = "\n".join(categories)
-        conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
-        )
-        return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
-            categories=categories_str,
-            conversations=conversations_str,
-        )
-
-    def get_shield_response(self, response: str) -> ShieldResponse:
-        response = response.strip()
-        if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
-        unsafe_code = self.check_unsafe_response(response)
-        if unsafe_code:
-            unsafe_code_list = unsafe_code.split(",")
-            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=unsafe_code,
-                violation_return_message=CANNED_RESPONSE_TEXT,
-            )
-
-        raise ValueError(f"Unexpected response: {response}")
-
-    def build_mm_prompt(self, messages: List[Message]) -> str:
-        conversation = []
-        most_recent_img = None
-
-        for m in messages[::-1]:
-            if isinstance(m.content, str):
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": [{"type": "text", "text": m.content}],
-                    }
-                )
-            elif isinstance(m.content, ImageMedia):
-                if most_recent_img is None and m.role == Role.user.value:
-                    most_recent_img = m.content
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": [{"type": "image"}],
-                    }
-                )
-
-            elif isinstance(m.content, list):
-                content = []
-                for c in m.content:
-                    if isinstance(c, str):
-                        content.append({"type": "text", "text": c})
-                    elif isinstance(c, ImageMedia):
-                        if most_recent_img is None and m.role == Role.user.value:
-                            most_recent_img = c
-                        content.append({"type": "image"})
-                    else:
-                        raise ValueError(f"Unknown content type: {c}")
-
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": content,
-                    }
-                )
-            else:
-                raise ValueError(f"Unknown content type: {m.content}")
-
-        return conversation[::-1], most_recent_img
-
-    async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
-        formatted_messages, most_recent_img = self.build_mm_prompt(messages)
-        raw_image = None
-        if most_recent_img:
-            raw_image = interleaved_text_media_localize(most_recent_img)
-            raw_image = raw_image.image
-        llama_guard_input_templ_applied = self.processor.apply_chat_template(
-            formatted_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            skip_special_tokens=False,
-        )
-        inputs = self.processor(
-            text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
-        ).to(self.device)
-        output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
-        response = self.processor.decode(
-            output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
-        )
-        shield_response = self.get_shield_response(response)
-        return shield_response
-
-    async def run_lg_text(self, messages: List[Message]):
-        prompt = self.build_prompt(messages)
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
-        prompt_len = input_ids.shape[1]
-        output = self.model.generate(
-            input_ids=input_ids,
-            max_new_tokens=20,
-            output_scores=True,
-            return_dict_in_generate=True,
-            pad_token_id=0,
-        )
-        generated_tokens = output.sequences[:, prompt_len:]
-
-        response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-
-        shield_response = self.get_shield_response(response)
-        return shield_response
-
-    def get_model_name(self):
-        return self.model_dir.split("/")[-1]
-
-    def is_lg_vision(self):
-        model_name = self.get_model_name()
-        return model_name == LG_3_11B_VISION
-
     def validate_messages(self, messages: List[Message]) -> None:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
@@ -334,14 +187,92 @@ class LlamaGuardShield(ShieldBase):
             return ShieldResponse(
                 is_violation=False,
             )
+
+        if self.model == LG_3_11B_VISION:
+            shield_input_message = self.build_vision_shield_input(messages)
         else:
+            shield_input_message = self.build_text_shield_input(messages)
-            if self.is_lg_vision():
-
-                shield_response = await self.run_lg_mm(messages)
-
-            else:
-
-                shield_response = await self.run_lg_text(messages)
+        # TODO: llama-stack inference protocol has issues with non-streaming inference code
+        content = ""
+        async for chunk in self.inference_api.chat_completion(
+            model=self.model,
+            messages=[shield_input_message],
+            stream=True,
+        ):
+            event = chunk.event
+            if event.event_type == ChatCompletionResponseEventType.progress:
+                assert isinstance(event.delta, str)
+                content += event.delta
 
+        content = content.strip()
+        shield_response = self.get_shield_response(content)
         return shield_response
+
+    def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
+        return UserMessage(content=self.build_prompt(messages))
+
+    def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
+        conversation = []
+        most_recent_img = None
+
+        for m in messages[::-1]:
+            if isinstance(m.content, str):
+                conversation.append(m)
+            elif isinstance(m.content, ImageMedia):
+                if most_recent_img is None and m.role == Role.user.value:
+                    most_recent_img = m.content
+                conversation.append(m)
+            elif isinstance(m.content, list):
+                content = []
+                for c in m.content:
+                    if isinstance(c, str):
+                        content.append(c)
+                    elif isinstance(c, ImageMedia):
+                        if most_recent_img is None and m.role == Role.user.value:
+                            content.append(c)
+                    else:
+                        raise ValueError(f"Unknown content type: {c}")
+
+                conversation.append(UserMessage(content=content))
+            else:
+                raise ValueError(f"Unknown content type: {m.content}")
+
+        content = []
+        if most_recent_img is not None:
+            content.append(most_recent_img)
+        content.append(self.build_prompt(conversation[::-1]))
+
+        return UserMessage(content=content)
+
+    def build_prompt(self, messages: List[Message]) -> str:
+        categories = self.get_safety_categories()
+        categories_str = "\n".join(categories)
+        conversations_str = "\n\n".join(
+            [
+                f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
+                for m in messages
+            ]
+        )
+        return PROMPT_TEMPLATE.substitute(
+            agent_type=messages[-1].role.capitalize(),
+            categories=categories_str,
+            conversations=conversations_str,
+        )
+
+    def get_shield_response(self, response: str) -> ShieldResponse:
+        response = response.strip()
+        if response == SAFE_RESPONSE:
+            return ShieldResponse(is_violation=False)
+        unsafe_code = self.check_unsafe_response(response)
+        if unsafe_code:
+            unsafe_code_list = unsafe_code.split(",")
+            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
+                return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                is_violation=True,
+                violation_type=unsafe_code,
+                violation_return_message=CANNED_RESPONSE_TEXT,
+            )
+
+        raise ValueError(f"Unexpected response: {response}")
diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index ac14eaaac..e0022f02b 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -21,10 +21,9 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
-                "accelerate",
                 "codeshield",
-                "torch",
                 "transformers",
+                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
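
Usage sketch (illustrative, not part of the patch): after this change LlamaGuardShield no longer loads Llama Guard weights through transformers; it hands prompt construction and generation to whatever Inference provider it is given. The wiring below is a minimal sketch under stated assumptions: the model identifier, the excluded category code, and the async run() entry point inherited from ShieldBase are assumptions for illustration, not verified against the rest of the tree.

# Minimal sketch, not part of the patch. Assumes `inference_api` satisfies the
# llama-stack Inference protocol and that ShieldBase exposes an async run().
from llama_stack.providers.impls.meta_reference.safety.shields.llama_guard import (
    LlamaGuardShield,
)


async def moderate(inference_api, messages):
    shield = LlamaGuardShield(
        model="Llama-Guard-3-8B",      # assumed model identifier, resolved by the inference provider
        inference_api=inference_api,   # replaces the old local model_dir + transformers loading
        excluded_categories=["S1"],    # optional; codes follow the ['S1', 'S2', ..] format
    )
    response = await shield.run(messages)  # internally streams chat_completion chunks
    if response.is_violation:
        print(response.violation_type, response.violation_return_message)
    return response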