feat: Add open ai compatible moderations api

This commit is contained in:
Swapna Lekkala 2025-08-01 15:54:00 -07:00
parent 0caef40e0d
commit c89fb40082
6 changed files with 549 additions and 0 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import re
import uuid
from string import Template
from typing import Any
@ -20,6 +21,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -67,6 +69,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
"violence": [CAT_VIOLENT_CRIMES],
"violence/graphic": [CAT_VIOLENT_CRIMES],
"harassment": [CAT_CHILD_EXPLOITATION],
"harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
"hate": [CAT_HATE],
"hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES],
"illicit": [CAT_NON_VIOLENT_CRIMES],
"illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
"sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
"sexual/minors": [CAT_CHILD_EXPLOITATION],
"self-harm": [CAT_SELF_HARM],
"self-harm/intent": [CAT_SELF_HARM],
"self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
@ -195,6 +222,45 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run(messages)
async def create(
self,
input: str | list[str],
model: str | None = None, # To replace with default model for llama-guard
) -> ModerationObject:
if model is None:
raise ValueError("Model cannot be None")
if not input or len(input) == 0:
raise ValueError("Input cannot be empty")
if isinstance(input, list):
messages = input.copy()
else:
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run_create(messages)
class LlamaGuardShield:
def __init__(
self,
@ -340,3 +406,94 @@ class LlamaGuardShield:
)
raise ValueError(f"Unexpected response: {response}")
async def run_create(self, messages: list[Message]) -> ModerationObject:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
messages=[shield_input_message],
stream=False,
)
content = response.completion_message.content
content = content.strip()
return self.get_moderation_object(content)
def create_safe_moderation_object(self, model: str) -> ModerationObject:
"""Create a ModerationObject for safe content."""
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 0.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=False,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
)
],
)
def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject:
"""Create a ModerationObject for unsafe content."""
unsafe_code_list = unsafe_code.split(",")
openai_categories = []
for code in unsafe_code_list:
if code not in SAFETY_CODE_TO_CATEGORIES_MAP:
raise ValueError(f"Unknown code: {code}")
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()
}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=True,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code_list},
)
],
)
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
"""Check if content is safe based on response and unsafe code."""
if response.strip() == SAFE_RESPONSE:
return True
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return True
return False
def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip()
if self.is_content_safe(response):
return self.create_safe_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response)
if not unsafe_code:
raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code):
return self.create_safe_moderation_object(self.model)
else:
return self.create_unsafe_moderation_object(self.model, unsafe_code)