feat: Add moderations create api (#3020)

# What does this PR do?
This PR adds Open AI Compatible moderations api. Currently only
implementing for llama guard safety provider
Image support, expand to other safety providers and Deprecation of
run_shield will be next steps.


## Test Plan
Added 2 new tests for safe/ unsafe text prompt examples for the new open
ai compatible moderations api usage
`SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest
-v tests/integration/safety/test_safety.py
--text-model=llama3.2:3b-instruct-fp16
--embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`
(Had some issue with previous PR
https://github.com/meta-llama/llama-stack/pull/2994 while updating and
accidentally close it , reopened new one )
This commit is contained in:
slekkala1 2025-08-06 13:51:23 -07:00 committed by GitHub
parent 0caef40e0d
commit 26d3d25c87
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 622 additions and 1 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, StrEnum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -15,6 +15,71 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
:param flagged: Whether any of the below categories are flagged.
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
Required set of categories that need to be in response
- violence
- violence/graphic
- harassment
- harassment/threatening
- hate
- hate/threatening
- illicit
- illicit/violent
- sexual
- sexual/minors
- self-harm
- self-harm/intent
- self-harm/instructions
"""
flagged: bool
categories: dict[str, bool] | None = None
category_applied_input_types: dict[str, list[str]] | None = None
category_scores: dict[str, float] | None = None
user_message: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class ModerationObject(BaseModel):
"""A moderation object.
:param id: The unique identifier for the moderation request.
:param model: The model used to generate the moderation results.
:param results: A list of moderation objects
"""
id: str
model: str
results: list[ModerationObjectResults]
@json_schema_type
class ViolationLevel(Enum):
"""Severity level of a safety violation.
@ -82,3 +147,13 @@ class Safety(Protocol):
:returns: A RunShieldResponse.
"""
...
@webmethod(route="/openai/v1/moderations", method="POST")
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use.
:returns: A moderation object.
"""
...

View file

@ -10,6 +10,7 @@ from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -60,3 +61,41 @@ class SafetyRouter(Safety):
messages=messages,
params=params,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
async def get_shield_id(self, model: str) -> str:
"""Get Shield id from model (provider_resource_id) of shield."""
list_shields_response = await self.routing_table.list_shields()
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1:
raise ValueError(f"Multiple shields associated with provider_resource id {model}")
return matches[0]
shield_id = await get_shield_id(self, model)
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_moderation(
input=input,
model=model,
)
self._validate_required_categories_exist(response)
return response
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
required_categories = list(map(str, OpenAICategories))
categories = response.results[0].categories
category_applied_input_types = response.results[0].category_applied_input_types
category_scores = response.results[0].category_scores
for i in [categories, category_applied_input_types, category_scores]:
if not set(required_categories).issubset(set(i.keys())):
raise ValueError(
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
)

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
import uuid
from string import Template
from typing import Any
@ -20,6 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -67,6 +70,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
@ -194,6 +222,34 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
if isinstance(input, list):
messages = input.copy()
else:
messages = [input]
# convert to user messages format with role
messages = [UserMessage(content=m) for m in messages]
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run_moderation(messages)
class LlamaGuardShield:
def __init__(
@ -340,3 +396,117 @@ class LlamaGuardShield:
)
raise ValueError(f"Unexpected response: {response}")
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
stream=False,
)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
"""Create a ModerationObject for either safe or unsafe content.
Args:
model: The model name
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
Returns:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
flagged = False
user_message = None
metadata = {}
# Handle unsafe case
if unsafe_code:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
}
flagged = True
user_message = CANNED_RESPONSE_TEXT
metadata = {"violation_type": unsafe_code_list}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
"""Check if content is safe based on response and unsafe code."""
if response.strip() == SAFE_RESPONSE:
return True
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return True
return False
def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip()
if self.is_content_safe(response):
return self.create_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response)
if not unsafe_code:
raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code):
return self.create_moderation_object(self.model)
else:
return self.create_moderation_object(self.model, unsafe_code)