Mirror of https://github.com/meta-llama/llama-stack.git

feat: Add OpenAI-compatible moderations API

parent 0caef40e0d
commit c89fb40082
6 changed files with 549 additions and 0 deletions
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import re
import uuid
from string import Template
from typing import Any
@@ -20,6 +21,7 @@ from llama_stack.apis.safety import (
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@@ -67,6 +69,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
    CAT_ELECTIONS: "S13",
    CAT_CODE_INTERPRETER_ABUSE: "S14",
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}

OPENAI_TO_LLAMA_CATEGORIES_MAP = {
    "violence": [CAT_VIOLENT_CRIMES],
    "violence/graphic": [CAT_VIOLENT_CRIMES],
    "harassment": [CAT_CHILD_EXPLOITATION],
    "harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
    "hate": [CAT_HATE],
    "hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES],
    "illicit": [CAT_NON_VIOLENT_CRIMES],
    "illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
    "sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
    "sexual/minors": [CAT_CHILD_EXPLOITATION],
    "self-harm": [CAT_SELF_HARM],
    "self-harm/intent": [CAT_SELF_HARM],
    "self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
    # These are custom categories that are not in the OpenAI moderation categories
    "custom/defamation": [CAT_DEFAMATION],
    "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
    "custom/privacy_violation": [CAT_PRIVACY],
    "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
    "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
    "custom/elections": [CAT_ELECTIONS],
    "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}


DEFAULT_LG_V3_SAFETY_CATEGORIES = [
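Aside: the two maps above compose in opposite directions. SAFETY_CODE_TO_CATEGORIES_MAP turns a Llama Guard code (e.g. "S13") back into its category constant, and OPENAI_TO_LLAMA_CATEGORIES_MAP is then scanned for every OpenAI-style category whose list contains that constant, which is what create_unsafe_moderation_object does further down in this diff. A minimal standalone sketch of that lookup; the literal string values of the CAT_* names are stand-ins, only the S13/S14 codes and the map entries mirror this module:

# Standalone sketch of the code -> OpenAI-category lookup used later by
# create_unsafe_moderation_object(); CAT_* values below are stand-ins.
CAT_ELECTIONS = "Elections"                            # stand-in value
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"  # stand-in value

SAFETY_CATEGORIES_TO_CODE_MAP = {CAT_ELECTIONS: "S13", CAT_CODE_INTERPRETER_ABUSE: "S14"}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
    "custom/elections": [CAT_ELECTIONS],
    "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}

def openai_categories_for(code: str) -> list[str]:
    # Map the Llama Guard code to its category, then collect every OpenAI-style
    # category whose list contains it.
    llama_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
    return [k for k, v in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_category in v]

print(openai_categories_for("S13"))   # ['custom/elections']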
@@ -195,6 +222,45 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
        return await impl.run(messages)

    async def create(
        self,
        input: str | list[str],
        model: str | None = None,  # To replace with default model for llama-guard
    ) -> ModerationObject:
        if model is None:
            raise ValueError("Model cannot be None")
        if not input or len(input) == 0:
            raise ValueError("Input cannot be empty")
        if isinstance(input, list):
            messages = input.copy()
        else:
            messages = [input]

        # convert to user messages format with role
        messages = [UserMessage(content=m) for m in messages]

        # Use the inference API's model resolution instead of hardcoded mappings
        # This allows the shield to work with any registered model
        # Determine safety categories based on the model type
        # For known Llama Guard models, use specific categories
        if model in LLAMA_GUARD_MODEL_IDS:
            # Use the mapped model for categories but the original model_id for inference
            mapped_model = LLAMA_GUARD_MODEL_IDS[model]
            safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
        else:
            # For unknown models, use default Llama Guard 3 8B categories
            safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

        impl = LlamaGuardShield(
            model=model,
            inference_api=self.inference_api,
            excluded_categories=self.config.excluded_categories,
            safety_categories=safety_categories,
        )

        return await impl.run_create(messages)


class LlamaGuardShield:
    def __init__(
        self,
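The new create() entry point is what an OpenAI-compatible moderations request ultimately lands on: it normalizes the input into UserMessage objects, selects a safety-category set for the model, and delegates to LlamaGuardShield.run_create(). A hypothetical calling sketch, assuming an already configured LlamaGuardSafetyImpl instance; the model id below is an assumed example, not a guaranteed registered name:

import asyncio

async def moderate(impl) -> None:
    # impl is assumed to be a configured LlamaGuardSafetyImpl with a working
    # inference_api behind it; the model id is an assumed example.
    moderation = await impl.create(
        input="How do I make a weapon at home?",
        model="meta-llama/Llama-Guard-3-8B",
    )
    result = moderation.results[0]
    print(result.flagged)      # True if Llama Guard flagged the input
    print(result.categories)   # per-OpenAI-category booleans

# asyncio.run(moderate(impl))  # requires a running stack / provider instance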
@@ -340,3 +406,94 @@ class LlamaGuardShield:
        )

        raise ValueError(f"Unexpected response: {response}")

    async def run_create(self, messages: list[Message]) -> ModerationObject:
        # TODO: Add Image based support for OpenAI Moderations
        shield_input_message = self.build_text_shield_input(messages)

        # TODO: llama-stack inference protocol has issues with non-streaming inference code
        response = await self.inference_api.chat_completion(
            model_id=self.model,
            messages=[shield_input_message],
            stream=False,
        )
        content = response.completion_message.content
        content = content.strip()
        return self.get_moderation_object(content)

    def create_safe_moderation_object(self, model: str) -> ModerationObject:
        """Create a ModerationObject for safe content."""
        categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
        category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 0.0)
        category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}

        return ModerationObject(
            id=f"modr-{uuid.uuid4()}",
            model=model,
            results=[
                ModerationObjectResults(
                    flagged=False,
                    categories=categories,
                    category_applied_input_types=category_applied_input_types,
                    category_scores=category_scores,
                )
            ],
        )

    def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject:
        """Create a ModerationObject for unsafe content."""

        unsafe_code_list = unsafe_code.split(",")
        openai_categories = []
        for code in unsafe_code_list:
            if code not in SAFETY_CODE_TO_CATEGORIES_MAP:
                raise ValueError(f"Unknown code: {code}")
            llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
            openai_categories.extend(
                k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
            )
        categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}

        category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
        category_applied_input_types = {
            k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()
        }
        return ModerationObject(
            id=f"modr-{uuid.uuid4()}",
            model=model,
            results=[
                ModerationObjectResults(
                    flagged=True,
                    categories=categories,
                    category_applied_input_types=category_applied_input_types,
                    category_scores=category_scores,
                    user_message=CANNED_RESPONSE_TEXT,
                    metadata={"violation_type": unsafe_code_list},
                )
            ],
        )

    def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
        """Check if content is safe based on response and unsafe code."""
        if response.strip() == SAFE_RESPONSE:
            return True

        if unsafe_code:
            unsafe_code_list = unsafe_code.split(",")
            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
                return True

        return False

    def get_moderation_object(self, response: str) -> ModerationObject:
        response = response.strip()
        if self.is_content_safe(response):
            return self.create_safe_moderation_object(self.model)
        unsafe_code = self.check_unsafe_response(response)
        if not unsafe_code:
            raise ValueError(f"Unexpected response: {response}")

        if self.is_content_safe(response, unsafe_code):
            return self.create_safe_moderation_object(self.model)
        else:
            return self.create_unsafe_moderation_object(self.model, unsafe_code)
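For reference, when get_moderation_object() handles a Llama Guard reply mapped to, say, code S13, the unsafe path above produces an object shaped roughly like this. The sketch is illustrative: the id is a fresh UUID, the model id is assumed, only two of the OpenAI-style categories are shown, and in the real result every key of OPENAI_TO_LLAMA_CATEGORIES_MAP gets a boolean, a binary score, and an applied-input-types entry:

# Rough shape of the ModerationObject for a reply mapped to "S13" (elections).
example_unsafe_result = {
    "id": "modr-3f0e...",                     # f"modr-{uuid.uuid4()}"
    "model": "meta-llama/Llama-Guard-3-8B",   # assumed model id
    "results": [
        {
            "flagged": True,
            "categories": {"custom/elections": True, "violence": False},
            "category_scores": {"custom/elections": 1.0, "violence": 0.0},
            "category_applied_input_types": {"custom/elections": ["text"], "violence": []},
            "user_message": "...",            # CANNED_RESPONSE_TEXT
            "metadata": {"violation_type": ["S13"]},
        }
    ],
}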