mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 20:27:35 +00:00
Merge branch 'meta-llama:main' into main
This commit is contained in:
commit
cd64371b2e
28 changed files with 286 additions and 283 deletions
|
@ -7,12 +7,13 @@
|
|||
from typing import Optional
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str = Field(
|
||||
|
@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
or m.core_model_id == CoreModelId.llama_guard_3_8b
|
||||
]
|
||||
permitted_models = supported_inference_models()
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
|
|
|
@ -52,7 +52,7 @@ def model_checkpoint_dir(model) -> str:
|
|||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoint dir: {checkpoint_dir}."
|
||||
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
|
||||
f"Please download model using `llama download --model-id {model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
|
|
@ -14,6 +14,10 @@ import torch
|
|||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.llama3.api.model import Transformer, TransformerBlock
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.apis.inference.config import (
|
||||
|
@ -21,9 +25,6 @@ from llama_stack.apis.inference.config import (
|
|||
MetaReferenceImplConfig,
|
||||
)
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def is_fbgemm_available() -> bool:
|
||||
try:
|
||||
|
|
|
@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
assert (
|
||||
cfg is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
model_dir = resolve_and_get_path(cfg.model)
|
||||
|
||||
return LlamaGuardShield(
|
||||
model_dir=model_dir,
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
|
|
|
@ -9,17 +9,10 @@ import re
|
|||
from string import Template
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
MllamaForConditionalGeneration,
|
||||
MllamaProcessor
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
SAFE_RESPONSE = "safe"
|
||||
|
@ -75,20 +68,17 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
|||
CAT_ELECTIONS,
|
||||
]
|
||||
|
||||
# model names
|
||||
LG_3_8B = "Llama-Guard-3-8B"
|
||||
LG_3_1B = "Llama-Guard-3-1B"
|
||||
LG_3_11B_VISION = "Llama-Guard-3-11B-Vision"
|
||||
|
||||
|
||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||
LG_3_8B: DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
|
||||
LG_3_1B: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
LG_3_11B_VISION: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
CoreModelId.llama_guard_3_8b.value: (
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
),
|
||||
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
}
|
||||
|
||||
|
||||
PROMPT_TASK = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
|
||||
|
||||
SAFETY_CATEGORIES = """
|
||||
|
||||
|
@ -109,7 +99,7 @@ $conversations
|
|||
PROMPT_INSTRUCTIONS = """
|
||||
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
|
||||
- First line must read 'safe' or 'unsafe'.
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
|
||||
- If unsafe, a second line must include a comma-separated list of violated categories."""
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = Template(
|
||||
|
@ -120,7 +110,8 @@ PROMPT_TEMPLATE = Template(
|
|||
class LlamaGuardShield(ShieldBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
|
@ -128,12 +119,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
):
|
||||
super().__init__(on_violation_action)
|
||||
|
||||
dtype = torch.bfloat16
|
||||
self.model_dir = model_dir
|
||||
self.device = "cuda"
|
||||
|
||||
assert self.model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
|
@ -141,27 +126,15 @@ class LlamaGuardShield(ShieldBase):
|
|||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.disable_input_check = disable_input_check
|
||||
self.disable_output_check = disable_output_check
|
||||
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
self.model_dir = f"meta-llama/{self.get_model_name()}"
|
||||
|
||||
if self.is_lg_vision():
|
||||
|
||||
self.model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.model_dir, device_map=self.device, torch_dtype=torch_dtype
|
||||
)
|
||||
self.processor = MllamaProcessor.from_pretrained(self.model_dir)
|
||||
else:
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_dir, torch_dtype=torch_dtype, device_map=self.device
|
||||
)
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
|
@ -177,7 +150,8 @@ class LlamaGuardShield(ShieldBase):
|
|||
excluded_categories = []
|
||||
|
||||
final_categories = []
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.get_model_name()]
|
||||
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
|
||||
for cat in all_categories:
|
||||
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||
if cat_code in excluded_categories:
|
||||
|
@ -186,11 +160,99 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: List[Message]) -> None:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (
|
||||
messages[0].role == Role.user.value and messages[1].role == Role.user.value
|
||||
):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
shield_input_message = self.build_vision_shield_input(messages)
|
||||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
content = ""
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=True,
|
||||
):
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.progress:
|
||||
assert isinstance(event.delta, str)
|
||||
content += event.delta
|
||||
|
||||
content = content.strip()
|
||||
shield_response = self.get_shield_response(content)
|
||||
return shield_response
|
||||
|
||||
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
content.append(c)
|
||||
elif isinstance(c, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append(c)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(UserMessage(content=content))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
prompt = []
|
||||
if most_recent_img is not None:
|
||||
prompt.append(most_recent_img)
|
||||
prompt.append(self.build_prompt(conversation[::-1]))
|
||||
|
||||
return UserMessage(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: List[Message]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[f"{m.role.capitalize()}: {m.content}" for m in messages]
|
||||
[
|
||||
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
return PROMPT_TEMPLATE.substitute(
|
||||
agent_type=messages[-1].role.capitalize(),
|
||||
|
@ -214,134 +276,3 @@ class LlamaGuardShield(ShieldBase):
|
|||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
def build_mm_prompt(self, messages: List[Message]) -> str:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": [{"type": "text", "text": m.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": [{"type": "image"}],
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
content.append({"type": "text", "text": c})
|
||||
elif isinstance(c, ImageMedia):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append({"type": "image"})
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
return conversation[::-1], most_recent_img
|
||||
|
||||
async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
|
||||
formatted_messages, most_recent_img = self.build_mm_prompt(messages)
|
||||
raw_image = None
|
||||
if most_recent_img:
|
||||
raw_image = interleaved_text_media_localize(most_recent_img)
|
||||
raw_image = raw_image.image
|
||||
llama_guard_input_templ_applied = self.processor.apply_chat_template(
|
||||
formatted_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
inputs = self.processor(
|
||||
text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
|
||||
response = self.processor.decode(
|
||||
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)
|
||||
shield_response = self.get_shield_response(response)
|
||||
return shield_response
|
||||
|
||||
async def run_lg_text(self, messages: List[Message]):
|
||||
prompt = self.build_prompt(messages)
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
||||
prompt_len = input_ids.shape[1]
|
||||
output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=20,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=0,
|
||||
)
|
||||
generated_tokens = output.sequences[:, prompt_len:]
|
||||
|
||||
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
||||
|
||||
shield_response = self.get_shield_response(response)
|
||||
return shield_response
|
||||
|
||||
def get_model_name(self):
|
||||
return self.model_dir.split("/")[-1]
|
||||
|
||||
def is_lg_vision(self):
|
||||
model_name = self.get_model_name()
|
||||
return model_name == LG_3_11B_VISION
|
||||
|
||||
def validate_messages(self, messages: List[Message]) -> None:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
raise ValueError("Messages must start with user")
|
||||
|
||||
if len(messages) >= 2 and (
|
||||
messages[0].role == Role.user.value and messages[1].role == Role.user.value
|
||||
):
|
||||
messages = messages[1:]
|
||||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
|
||||
messages = self.validate_messages(messages)
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
|
||||
return ShieldResponse(
|
||||
is_violation=False,
|
||||
)
|
||||
else:
|
||||
|
||||
if self.is_lg_vision():
|
||||
|
||||
shield_response = await self.run_lg_mm(messages)
|
||||
|
||||
else:
|
||||
|
||||
shield_response = await self.run_lg_text(messages)
|
||||
|
||||
return shield_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue