mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Use inference APIs for executing Llama Guard

commit 37ca22cda6 (parent 6236634d84)
3 changed files with 94 additions and 164 deletions
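Summary: the meta-reference safety provider no longer loads and runs Llama Guard locally with torch/transformers. LlamaGuardShield now takes a model identifier plus an Inference API handle, builds a shield input message, and obtains its verdict through a streaming chat_completion call, which in turn lets the provider drop its GPU-oriented dependencies.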
@@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
         assert (
             cfg is not None
         ), "Cannot use LlamaGuardShield since not present in config"
-        model_dir = resolve_and_get_path(cfg.model)

         return LlamaGuardShield(
-            model_dir=model_dir,
+            model=cfg.model,
+            inference_api=self.inference_api,
             excluded_categories=cfg.excluded_categories,
             disable_input_check=cfg.disable_input_check,
             disable_output_check=cfg.disable_output_check,
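Note: the resolve_and_get_path call disappears because no local checkpoint is needed any more; the shield is handed the configured model identifier and the provider's own inference_api handle, so the safety provider now depends on an inference provider being available at runtime.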
@@ -9,14 +9,8 @@ import re
 from string import Template
 from typing import List, Optional

-import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    MllamaForConditionalGeneration,
-    MllamaProcessor
-)
-
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -120,7 +114,8 @@ PROMPT_TEMPLATE = Template(
 class LlamaGuardShield(ShieldBase):
     def __init__(
         self,
-        model_dir: str,
+        model: str,
+        inference_api: Inference,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
@@ -128,12 +123,6 @@ class LlamaGuardShield(ShieldBase):
     ):
         super().__init__(on_violation_action)

-        dtype = torch.bfloat16
-        self.model_dir = model_dir
-        self.device = "cuda"
-
-        assert self.model_dir is not None, "Llama Guard model_dir is None"
-
         if excluded_categories is None:
             excluded_categories = []
@@ -141,27 +130,12 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

+        self.model = model
+        self.inference_api = inference_api
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check

-        torch_dtype = torch.bfloat16
-
-        self.model_dir = f"meta-llama/{self.get_model_name()}"
-
-        if self.is_lg_vision():
-
-            self.model = MllamaForConditionalGeneration.from_pretrained(
-                self.model_dir, device_map=self.device, torch_dtype=torch_dtype
-            )
-            self.processor = MllamaProcessor.from_pretrained(self.model_dir)
-        else:
-
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_dir, torch_dtype=torch_dtype, device_map=self.device
-            )
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
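Note: after these two hunks the constructor only records configuration (the model identifier, the Inference handle, the excluded categories, and the two check toggles). The device/dtype setup and the transformers model, tokenizer, and processor loading are gone; generation now happens through the Inference API in run() further down.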
@@ -186,127 +160,6 @@ class LlamaGuardShield(ShieldBase):
         return final_categories

-    def build_prompt(self, messages: List[Message]) -> str:
-        categories = self.get_safety_categories()
-        categories_str = "\n".join(categories)
-        conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
-        )
-        return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
-            categories=categories_str,
-            conversations=conversations_str,
-        )
-
-    def get_shield_response(self, response: str) -> ShieldResponse:
-        response = response.strip()
-        if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
-        unsafe_code = self.check_unsafe_response(response)
-        if unsafe_code:
-            unsafe_code_list = unsafe_code.split(",")
-            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
-            return ShieldResponse(
-                is_violation=True,
-                violation_type=unsafe_code,
-                violation_return_message=CANNED_RESPONSE_TEXT,
-            )
-
-        raise ValueError(f"Unexpected response: {response}")
-
-    def build_mm_prompt(self, messages: List[Message]) -> str:
-        conversation = []
-        most_recent_img = None
-
-        for m in messages[::-1]:
-            if isinstance(m.content, str):
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": [{"type": "text", "text": m.content}],
-                    }
-                )
-            elif isinstance(m.content, ImageMedia):
-                if most_recent_img is None and m.role == Role.user.value:
-                    most_recent_img = m.content
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": [{"type": "image"}],
-                    }
-                )
-
-            elif isinstance(m.content, list):
-                content = []
-                for c in m.content:
-                    if isinstance(c, str):
-                        content.append({"type": "text", "text": c})
-                    elif isinstance(c, ImageMedia):
-                        if most_recent_img is None and m.role == Role.user.value:
-                            most_recent_img = c
-                        content.append({"type": "image"})
-                    else:
-                        raise ValueError(f"Unknown content type: {c}")
-
-                conversation.append(
-                    {
-                        "role": m.role,
-                        "content": content,
-                    }
-                )
-            else:
-                raise ValueError(f"Unknown content type: {m.content}")
-
-        return conversation[::-1], most_recent_img
-
-    async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
-        formatted_messages, most_recent_img = self.build_mm_prompt(messages)
-        raw_image = None
-        if most_recent_img:
-            raw_image = interleaved_text_media_localize(most_recent_img)
-            raw_image = raw_image.image
-        llama_guard_input_templ_applied = self.processor.apply_chat_template(
-            formatted_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            skip_special_tokens=False,
-        )
-        inputs = self.processor(
-            text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
-        ).to(self.device)
-        output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
-        response = self.processor.decode(
-            output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
-        )
-        shield_response = self.get_shield_response(response)
-        return shield_response
-
-    async def run_lg_text(self, messages: List[Message]):
-        prompt = self.build_prompt(messages)
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
-        prompt_len = input_ids.shape[1]
-        output = self.model.generate(
-            input_ids=input_ids,
-            max_new_tokens=20,
-            output_scores=True,
-            return_dict_in_generate=True,
-            pad_token_id=0,
-        )
-        generated_tokens = output.sequences[:, prompt_len:]

-        response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-
-        shield_response = self.get_shield_response(response)
-        return shield_response
-
-    def get_model_name(self):
-        return self.model_dir.split("/")[-1]
-
-    def is_lg_vision(self):
-        model_name = self.get_model_name()
-        return model_name == LG_3_11B_VISION
-
     def validate_messages(self, messages: List[Message]) -> None:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
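Note: the locally-executing helpers (build_prompt, get_shield_response, build_mm_prompt, run_lg_mm, run_lg_text, get_model_name, is_lg_vision) are deleted here. Prompt building and response parsing reappear in API-oriented form in the next hunk (build_text_shield_input, build_vision_shield_input, and re-added build_prompt/get_shield_response); local generation itself has no replacement in this file.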
@@ -334,14 +187,92 @@ class LlamaGuardShield(ShieldBase):
             return ShieldResponse(
                 is_violation=False,
             )

+        if self.model == LG_3_11B_VISION:
+            shield_input_message = self.build_vision_shield_input(messages)
         else:
+            shield_input_message = self.build_text_shield_input(messages)

-            if self.is_lg_vision():
-
-                shield_response = await self.run_lg_mm(messages)
-
-            else:
-
-                shield_response = await self.run_lg_text(messages)
+        # TODO: llama-stack inference protocol has issues with non-streaming inference code
+        content = ""
+        async for chunk in self.inference_api.chat_completion(
+            model=self.model,
+            messages=[shield_input_message],
+            stream=True,
+        ):
+            event = chunk.event
+            if event.event_type == ChatCompletionResponseEventType.progress:
+                assert isinstance(event.delta, str)
+                content += event.delta
+
+        content = content.strip()
+        shield_response = self.get_shield_response(content)
         return shield_response
+
+    def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
+        return UserMessage(content=self.build_prompt(messages))
+
+    def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
+        conversation = []
+        most_recent_img = None
+
+        for m in messages[::-1]:
+            if isinstance(m.content, str):
+                conversation.append(m)
+            elif isinstance(m.content, ImageMedia):
+                if most_recent_img is None and m.role == Role.user.value:
+                    most_recent_img = m.content
+                conversation.append(m)
+            elif isinstance(m.content, list):
+                content = []
+                for c in m.content:
+                    if isinstance(c, str):
+                        content.append(c)
+                    elif isinstance(c, ImageMedia):
+                        if most_recent_img is None and m.role == Role.user.value:
+                            content.append(c)
+                    else:
+                        raise ValueError(f"Unknown content type: {c}")
+
+                conversation.append(UserMessage(content=content))
+            else:
+                raise ValueError(f"Unknown content type: {m.content}")
+
+        content = []
+        if most_recent_img is not None:
+            content.append(most_recent_img)
+        content.append(self.build_prompt(conversation[::-1]))
+
+        return UserMessage(content=content)
+
+    def build_prompt(self, messages: List[Message]) -> str:
+        categories = self.get_safety_categories()
+        categories_str = "\n".join(categories)
+        conversations_str = "\n\n".join(
+            [
+                f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
+                for m in messages
+            ]
+        )
+        return PROMPT_TEMPLATE.substitute(
+            agent_type=messages[-1].role.capitalize(),
+            categories=categories_str,
+            conversations=conversations_str,
+        )
+
+    def get_shield_response(self, response: str) -> ShieldResponse:
+        response = response.strip()
+        if response == SAFE_RESPONSE:
+            return ShieldResponse(is_violation=False)
+        unsafe_code = self.check_unsafe_response(response)
+        if unsafe_code:
+            unsafe_code_list = unsafe_code.split(",")
+            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
+                return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                is_violation=True,
+                violation_type=unsafe_code,
+                violation_return_message=CANNED_RESPONSE_TEXT,
+            )
+
+        raise ValueError(f"Unexpected response: {response}")
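Note: the heart of this hunk is the streaming-accumulation pattern: concatenate the progress deltas from chat_completion(..., stream=True) into a single string, then parse Llama Guard's safe / unsafe verdict. Below is a minimal, self-contained sketch of that pattern only; the dataclasses and fake_chat_completion generator are stand-ins for the llama-stack event types (and the literal "safe" stands in for SAFE_RESPONSE, whose actual value is not shown in this diff), not the real API.

import asyncio
import re
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


# Stand-ins for the llama-stack streaming event types used in the hunk above.
@dataclass
class ProgressEvent:
    event_type: str  # "progress" plays the role of ChatCompletionResponseEventType.progress
    delta: str


@dataclass
class Chunk:
    event: ProgressEvent


async def fake_chat_completion(prompt: str) -> AsyncIterator[Chunk]:
    # Pretend a served Llama Guard model streams an "unsafe" verdict piece by piece.
    for piece in ["unsafe", "\n", "S1,", "S10"]:
        yield Chunk(ProgressEvent("progress", piece))


def parse_verdict(content: str, excluded: List[str]) -> Optional[str]:
    # Mirrors get_shield_response(): "safe" means no violation; "unsafe\nS..." lists codes,
    # and a verdict whose codes are all excluded is treated as safe.
    content = content.strip()
    if content == "safe":
        return None
    match = re.match(r"^unsafe\n(.*)$", content)
    if match:
        codes = match.group(1).split(",")
        if set(codes).issubset(set(excluded)):
            return None
        return match.group(1)
    raise ValueError(f"Unexpected response: {content}")


async def main() -> None:
    # Accumulate streamed deltas exactly as run() does, then parse the full text once.
    content = ""
    async for chunk in fake_chat_completion("...Llama Guard prompt..."):
        if chunk.event.event_type == "progress":
            content += chunk.event.delta

    print(parse_verdict(content, excluded=["S10"]))  # prints "S1,S10" -> a violation


asyncio.run(main())

Streaming is used here because, per the TODO in the hunk, the non-streaming path of the inference protocol had issues at the time; once that is resolved, the same accumulation would presumably collapse to a single non-streaming completion call.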
@@ -21,10 +21,9 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
-                "accelerate",
                 "codeshield",
-                "torch",
                 "transformers",
+                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
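Note: with no local model execution left in this provider, accelerate and the default torch wheel are dropped; torch is instead pulled from the CPU-only index, presumably so that whatever still imports it gets a lightweight build rather than the CUDA wheels.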