llama-stack-mirror/llama_stack/apis/safety/safety.py
slekkala1 26d3d25c87
feat: Add moderations create api (#3020)
# What does this PR do?
This PR adds Open AI Compatible moderations api. Currently only
implementing for llama guard safety provider
Image support, expand to other safety providers and Deprecation of
run_shield will be next steps.


## Test Plan
Added 2 new tests for safe/ unsafe text prompt examples for the new open
ai compatible moderations api usage
`SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest
-v tests/integration/safety/test_safety.py
--text-model=llama3.2:3b-instruct-fp16
--embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`
(Had some issue with previous PR
https://github.com/meta-llama/llama-stack/pull/2994 while updating and
accidentally close it , reopened new one )
2025-08-06 13:51:23 -07:00

159 lines
5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
:param flagged: Whether any of the below categories are flagged.
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
Required set of categories that need to be in response
- violence
- violence/graphic
- harassment
- harassment/threatening
- hate
- hate/threatening
- illicit
- illicit/violent
- sexual
- sexual/minors
- self-harm
- self-harm/intent
- self-harm/instructions
"""
flagged: bool
categories: dict[str, bool] | None = None
category_applied_input_types: dict[str, list[str]] | None = None
category_scores: dict[str, float] | None = None
user_message: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class ModerationObject(BaseModel):
"""A moderation object.
:param id: The unique identifier for the moderation request.
:param model: The model used to generate the moderation results.
:param results: A list of moderation objects
"""
id: str
model: str
results: list[ModerationObjectResults]
@json_schema_type
class ViolationLevel(Enum):
"""Severity level of a safety violation.
:cvar INFO: Informational level violation that does not require action
:cvar WARN: Warning level violation that suggests caution but allows continuation
:cvar ERROR: Error level violation that requires blocking or intervention
"""
INFO = "info"
WARN = "warn"
ERROR = "error"
@json_schema_type
class SafetyViolation(BaseModel):
"""Details of a safety violation detected by content moderation.
:param violation_level: Severity level of the violation
:param user_message: (Optional) Message to convey to the user about the violation
:param metadata: Additional metadata including specific violation codes for debugging and telemetry
"""
violation_level: ViolationLevel
# what message should you convey to the user
user_message: str | None = None
# additional metadata (including specific violation codes) more for
# debugging, telemetry
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RunShieldResponse(BaseModel):
"""Response from running a safety shield.
:param violation: (Optional) Safety violation detected by the shield, if any
"""
violation: SafetyViolation | None = None
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
@trace_protocol
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST")
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...
@webmethod(route="/openai/v1/moderations", method="POST")
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use.
:returns: A moderation object.
"""
...