This commit is contained in:
Ashwin Bharambe 2024-09-28 15:21:32 -07:00
parent 37ca22cda6
commit 23028e26ff
7 changed files with 83 additions and 47 deletions

View file

@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_models.llama3.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
@ -69,20 +68,17 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
# model names
LG_3_8B = "Llama-Guard-3-8B"
LG_3_1B = "Llama-Guard-3-1B"
LG_3_11B_VISION = "Llama-Guard-3-11B-Vision"
MODEL_TO_SAFETY_CATEGORIES_MAP = {
LG_3_8B: DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
LG_3_1B: DEFAULT_LG_V3_SAFETY_CATEGORIES,
LG_3_11B_VISION: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_8b.value: (
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
),
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
PROMPT_TASK = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
@ -103,7 +99,7 @@ $conversations
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
@ -130,6 +126,9 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
raise ValueError(f"Unsupported model: {model}")
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
@ -151,7 +150,8 @@ class LlamaGuardShield(ShieldBase):
excluded_categories = []
final_categories = []
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.get_model_name()]
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
@ -179,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
return messages
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
@ -188,7 +187,7 @@ class LlamaGuardShield(ShieldBase):
is_violation=False,
)
if self.model == LG_3_11B_VISION:
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)
else:
shield_input_message = self.build_text_shield_input(messages)
@ -230,6 +229,7 @@ class LlamaGuardShield(ShieldBase):
content.append(c)
elif isinstance(c, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
else:
raise ValueError(f"Unknown content type: {c}")
@ -238,12 +238,12 @@ class LlamaGuardShield(ShieldBase):
else:
raise ValueError(f"Unknown content type: {m.content}")
content = []
prompt = []
if most_recent_img is not None:
content.append(most_recent_img)
content.append(self.build_prompt(conversation[::-1]))
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=content)
return UserMessage(content=prompt)
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
@ -254,6 +254,7 @@ class LlamaGuardShield(ShieldBase):
for m in messages
]
)
return conversations_str
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,