Use inference APIs for executing Llama Guard

This commit is contained in:
Ashwin Bharambe 2024-09-25 19:40:49 -07:00
parent 6236634d84
commit 37ca22cda6
3 changed files with 94 additions and 164 deletions

View file

@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
assert ( assert (
cfg is not None cfg is not None
), "Cannot use LlamaGuardShield since not present in config" ), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.model)
return LlamaGuardShield( return LlamaGuardShield(
model_dir=model_dir, model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories, excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check, disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check, disable_output_check=cfg.disable_output_check,

View file

@ -9,14 +9,8 @@ import re
from string import Template from string import Template
from typing import List, Optional from typing import List, Optional
import torch from llama_models.llama3.api.datatypes import * # noqa: F403
from transformers import ( from llama_stack.apis.inference import * # noqa: F403
AutoModelForCausalLM,
AutoTokenizer,
MllamaForConditionalGeneration,
MllamaProcessor
)
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
@ -120,7 +114,8 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase): class LlamaGuardShield(ShieldBase):
def __init__( def __init__(
self, self,
model_dir: str, model: str,
inference_api: Inference,
excluded_categories: List[str] = None, excluded_categories: List[str] = None,
disable_input_check: bool = False, disable_input_check: bool = False,
disable_output_check: bool = False, disable_output_check: bool = False,
@ -128,12 +123,6 @@ class LlamaGuardShield(ShieldBase):
): ):
super().__init__(on_violation_action) super().__init__(on_violation_action)
dtype = torch.bfloat16
self.model_dir = model_dir
self.device = "cuda"
assert self.model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None: if excluded_categories is None:
excluded_categories = [] excluded_categories = []
@ -141,27 +130,12 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]" ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check self.disable_output_check = disable_output_check
torch_dtype = torch.bfloat16
self.model_dir = f"meta-llama/{self.get_model_name()}"
if self.is_lg_vision():
self.model = MllamaForConditionalGeneration.from_pretrained(
self.model_dir, device_map=self.device, torch_dtype=torch_dtype
)
self.processor = MllamaProcessor.from_pretrained(self.model_dir)
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]: def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response) match = re.match(r"^unsafe\n(.*)$", response)
if match: if match:
@ -186,127 +160,6 @@ class LlamaGuardShield(ShieldBase):
return final_categories return final_categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return ShieldResponse(is_violation=False)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")
def build_mm_prompt(self, messages: List[Message]) -> str:
conversation = []
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
conversation.append(
{
"role": m.role,
"content": [{"type": "text", "text": m.content}],
}
)
elif isinstance(m.content, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(
{
"role": m.role,
"content": [{"type": "image"}],
}
)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
content.append({"type": "text", "text": c})
elif isinstance(c, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append({"type": "image"})
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(
{
"role": m.role,
"content": content,
}
)
else:
raise ValueError(f"Unknown content type: {m.content}")
return conversation[::-1], most_recent_img
async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
formatted_messages, most_recent_img = self.build_mm_prompt(messages)
raw_image = None
if most_recent_img:
raw_image = interleaved_text_media_localize(most_recent_img)
raw_image = raw_image.image
llama_guard_input_templ_applied = self.processor.apply_chat_template(
formatted_messages,
add_generation_prompt=True,
tokenize=False,
skip_special_tokens=False,
)
inputs = self.processor(
text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
).to(self.device)
output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
response = self.processor.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
shield_response = self.get_shield_response(response)
return shield_response
async def run_lg_text(self, messages: List[Message]):
prompt = self.build_prompt(messages)
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
shield_response = self.get_shield_response(response)
return shield_response
def get_model_name(self):
return self.model_dir.split("/")[-1]
def is_lg_vision(self):
model_name = self.get_model_name()
return model_name == LG_3_11B_VISION
def validate_messages(self, messages: List[Message]) -> None: def validate_messages(self, messages: List[Message]) -> None:
if len(messages) == 0: if len(messages) == 0:
raise ValueError("Messages must not be empty") raise ValueError("Messages must not be empty")
@ -334,14 +187,92 @@ class LlamaGuardShield(ShieldBase):
return ShieldResponse( return ShieldResponse(
is_violation=False, is_violation=False,
) )
if self.model == LG_3_11B_VISION:
shield_input_message = self.build_vision_shield_input(messages)
else: else:
shield_input_message = self.build_text_shield_input(messages)
if self.is_lg_vision(): # TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
shield_response = await self.run_lg_mm(messages) async for chunk in self.inference_api.chat_completion(
model=self.model,
else: messages=[shield_input_message],
stream=True,
shield_response = await self.run_lg_text(messages) ):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response return shield_response
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
conversation = []
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
conversation.append(m)
elif isinstance(m.content, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
content.append(c)
elif isinstance(c, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
content.append(c)
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
content = []
if most_recent_img is not None:
content.append(most_recent_img)
content.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=content)
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
for m in messages
]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
response = response.strip()
if response == SAFE_RESPONSE:
return ShieldResponse(is_violation=False)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")

View file

@ -21,10 +21,9 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety, api=Api.safety,
provider_id="meta-reference", provider_id="meta-reference",
pip_packages=[ pip_packages=[
"accelerate",
"codeshield", "codeshield",
"torch",
"transformers", "transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
], ],
module="llama_stack.providers.impls.meta_reference.safety", module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",