mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-13 16:46:09 +00:00
Use inference APIs for executing Llama Guard (#121)
We should use Inference APIs to execute Llama Guard instead of directly needing to use HuggingFace modeling related code. The actual inference consideration is handled by Inference.
This commit is contained in:
parent
6236634d84
commit
0a3999a9a4
9 changed files with 167 additions and 204 deletions
|
@ -7,12 +7,13 @@
|
|||
from typing import Optional
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str = Field(
|
||||
|
@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
or m.core_model_id == CoreModelId.llama_guard_3_8b
|
||||
]
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
|
|
|
@ -52,7 +52,7 @@ def model_checkpoint_dir(model) -> str:
|
|||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoint dir: {checkpoint_dir}."
|
||||
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
|
||||
f"Please download model using `llama download --model-id {model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
|
|
@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.model)
|
||||
|
||||
return LlamaGuardShield(
|
||||
model_dir=model_dir,
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
|
|
|
@ -9,17 +9,10 @@ import re
|
|||
from string import Template
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
MllamaForConditionalGeneration,
|
||||
MllamaProcessor
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
|
@ -75,20 +68,17 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
|||
CAT_ELECTIONS,
|
||||
]
|
||||
|
||||
# model names
|
||||
LG_3_8B = "Llama-Guard-3-8B"
|
||||
LG_3_1B = "Llama-Guard-3-1B"
|
||||
LG_3_11B_VISION = "Llama-Guard-3-11B-Vision"
|
||||
|
||||
|
||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||
LG_3_8B: DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
|
||||
LG_3_1B: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
LG_3_11B_VISION: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
CoreModelId.llama_guard_3_8b.value: (
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
),
|
||||
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
}
|
||||
|
||||
|
||||
PROMPT_TASK = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||
|
||||
SAFETY_CATEGORIES = """
|
||||
|
||||
|
@ -109,7 +99,7 @@ $conversations
|
|||
PROMPT_INSTRUCTIONS = """
|
||||
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
|
||||
- First line must read 'safe' or 'unsafe'.
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = Template(
|
||||
|
@ -120,7 +110,8 @@ PROMPT_TEMPLATE = Template(
|
|||
class LlamaGuardShield(ShieldBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
|
@ -128,12 +119,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
self.model_dir = model_dir
|
||||
self.device = "cuda"
|
||||
|
||||
assert self.model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
|
@ -141,27 +126,15 @@ class LlamaGuardShield(ShieldBase):
|
|||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
self.model_dir = f"meta-llama/{self.get_model_name()}"
|
||||
|
||||
if self.is_lg_vision():
|
||||
|
||||
self.model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.model_dir, device_map=self.device, torch_dtype=torch_dtype
|
||||
)
|
||||
self.processor = MllamaProcessor.from_pretrained(self.model_dir)
|
||||
else:
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_dir, torch_dtype=torch_dtype, device_map=self.device
|
||||
)
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
|
@ -177,7 +150,8 @@ class LlamaGuardShield(ShieldBase):
|
|||
excluded_categories = []
|
||||
|
||||
final_categories = []
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.get_model_name()]
|
||||
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
|
||||
for cat in all_categories:
|
||||
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||
if cat_code in excluded_categories:
|
||||
|
@ -186,11 +160,99 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: List[Message]) -> None:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (
|
||||
messages[0].role == Role.user.value and messages[1].role == Role.user.value
|
||||
):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
content = ""
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=True,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.progress:
|
||||
assert isinstance(event.delta, str)
|
||||
content += event.delta
|
||||
|
||||
content = content.strip()
|
||||
shield_response = self.get_shield_response(content)
|
||||
return shield_response
|
||||
|
||||
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
content.append(c)
|
||||
elif isinstance(c, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append(c)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(UserMessage(content=content))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
prompt = []
|
||||
if most_recent_img is not None:
|
||||
prompt.append(most_recent_img)
|
||||
prompt.append(self.build_prompt(conversation[::-1]))
|
||||
|
||||
return UserMessage(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: List[Message]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[f"{m.role.capitalize()}: {m.content}" for m in messages]
|
||||
[
|
||||
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
return PROMPT_TEMPLATE.substitute(
|
||||
agent_type=messages[-1].role.capitalize(),
|
||||
|
@ -214,134 +276,3 @@ class LlamaGuardShield(ShieldBase):
|
|||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
def build_mm_prompt(self, messages: List[Message]) -> str:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": [{"type": "text", "text": m.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": [{"type": "image"}],
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
content.append({"type": "text", "text": c})
|
||||
elif isinstance(c, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append({"type": "image"})
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
return conversation[::-1], most_recent_img
|
||||
|
||||
async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
|
||||
formatted_messages, most_recent_img = self.build_mm_prompt(messages)
|
||||
raw_image = None
|
||||
if most_recent_img:
|
||||
raw_image = interleaved_text_media_localize(most_recent_img)
|
||||
raw_image = raw_image.image
|
||||
llama_guard_input_templ_applied = self.processor.apply_chat_template(
|
||||
formatted_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
inputs = self.processor(
|
||||
text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
|
||||
response = self.processor.decode(
|
||||
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)
|
||||
shield_response = self.get_shield_response(response)
|
||||
return shield_response
|
||||
|
||||
async def run_lg_text(self, messages: List[Message]):
|
||||
prompt = self.build_prompt(messages)
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
||||
prompt_len = input_ids.shape[1]
|
||||
output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=20,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=0,
|
||||
)
|
||||
generated_tokens = output.sequences[:, prompt_len:]
|
||||
|
||||
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
||||
|
||||
shield_response = self.get_shield_response(response)
|
||||
return shield_response
|
||||
|
||||
def get_model_name(self):
|
||||
return self.model_dir.split("/")[-1]
|
||||
|
||||
def is_lg_vision(self):
|
||||
model_name = self.get_model_name()
|
||||
return model_name == LG_3_11B_VISION
|
||||
|
||||
def validate_messages(self, messages: List[Message]) -> None:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (
|
||||
messages[0].role == Role.user.value and messages[1].role == Role.user.value
|
||||
):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
else:
|
||||
|
||||
if self.is_lg_vision():
|
||||
|
||||
shield_response = await self.run_lg_mm(messages)
|
||||
|
||||
else:
|
||||
|
||||
shield_response = await self.run_lg_text(messages)
|
||||
|
||||
return shield_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue