Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
# What does this PR do?

This PR adds an OpenAI-compatible moderations API. It is currently implemented only for the Llama Guard safety provider. Image support, extending the API to other safety providers, and deprecating `run_shield` are the next steps.

## Test Plan

Added two new tests with safe and unsafe text prompt examples for the new OpenAI-compatible moderations API:

`SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`

(The previous PR https://github.com/meta-llama/llama-stack/pull/2994 ran into issues while being updated and was accidentally closed, so this one reopens the same change.)
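For reference, the new API is exposed through the `Safety` protocol shown below. The following is a minimal sketch, not taken from the PR itself: `safety` stands in for whichever configured `Safety` implementation you have (for example the Llama Guard provider), and the model name follows the test plan above.

```python
# Sketch: exercising the OpenAI-compatible moderations API through the Safety
# protocol. `safety` is assumed to be an already-configured Safety
# implementation; the model name follows the test plan above.
async def check_inputs(safety) -> None:
    moderation = await safety.run_moderation(
        input=["I want to hurt someone.", "What is the capital of France?"],
        model="llama-guard3:8b",
    )
    for result in moderation.results:
        print(result.flagged, result.categories)
```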
The updated safety API (Python, 159 lines, 5 KiB):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, StrEnum
from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod


# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
    """
    Required set of categories in moderations api response
    """

    VIOLENCE = "violence"
    VIOLENCE_GRAPHIC = "violence/graphic"
    HARRASMENT = "harassment"
    HARRASMENT_THREATENING = "harassment/threatening"
    HATE = "hate"
    HATE_THREATENING = "hate/threatening"
    ILLICIT = "illicit"
    ILLICIT_VIOLENT = "illicit/violent"
    SEXUAL = "sexual"
    SEXUAL_MINORS = "sexual/minors"
    SELF_HARM = "self-harm"
    SELF_HARM_INTENT = "self-harm/intent"
    SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
```
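Because `OpenAICategories` derives from `StrEnum`, each member is itself a string, so the category names above double as the plain-string keys used in the result dictionaries defined next. A small illustration (not part of the file):

```python
# StrEnum members are str instances and compare equal to their values, so
# providers can use them directly when building the str-keyed result dicts.
assert OpenAICategories.VIOLENCE == "violence"
assert isinstance(OpenAICategories.SELF_HARM, str)
assert str(OpenAICategories.HATE) == "hate"
```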
```python
@json_schema_type
class ModerationObjectResults(BaseModel):
    """A moderation object.
    :param flagged: Whether any of the below categories are flagged.
    :param categories: A list of the categories, and whether they are flagged or not.
    :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
    :param category_scores: A list of the categories along with their scores as predicted by model.
    Required set of categories that need to be in response
    - violence
    - violence/graphic
    - harassment
    - harassment/threatening
    - hate
    - hate/threatening
    - illicit
    - illicit/violent
    - sexual
    - sexual/minors
    - self-harm
    - self-harm/intent
    - self-harm/instructions
    """

    flagged: bool
    categories: dict[str, bool] | None = None
    category_applied_input_types: dict[str, list[str]] | None = None
    category_scores: dict[str, float] | None = None
    user_message: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class ModerationObject(BaseModel):
    """A moderation object.
    :param id: The unique identifier for the moderation request.
    :param model: The model used to generate the moderation results.
    :param results: A list of moderation objects
    """

    id: str
    model: str
    results: list[ModerationObjectResults]
```
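Putting the two models together, a response for a single flagged input might look like the following sketch. The id, model name, flags, and scores are invented for illustration, and the import path is assumed from the package layout implied by the imports above.

```python
# Hypothetical moderation response for one flagged input; all values invented.
# Import path assumed from the package layout implied above.
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults, OpenAICategories

example = ModerationObject(
    id="modr-0001",
    model="llama-guard3:8b",
    results=[
        ModerationObjectResults(
            flagged=True,
            categories={cat.value: cat == OpenAICategories.VIOLENCE for cat in OpenAICategories},
            category_scores={cat.value: 0.98 if cat == OpenAICategories.VIOLENCE else 0.0 for cat in OpenAICategories},
            category_applied_input_types={OpenAICategories.VIOLENCE.value: ["text"]},
        )
    ],
)
```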
```python
@json_schema_type
class ViolationLevel(Enum):
    """Severity level of a safety violation.

    :cvar INFO: Informational level violation that does not require action
    :cvar WARN: Warning level violation that suggests caution but allows continuation
    :cvar ERROR: Error level violation that requires blocking or intervention
    """

    INFO = "info"
    WARN = "warn"
    ERROR = "error"


@json_schema_type
class SafetyViolation(BaseModel):
    """Details of a safety violation detected by content moderation.

    :param violation_level: Severity level of the violation
    :param user_message: (Optional) Message to convey to the user about the violation
    :param metadata: Additional metadata including specific violation codes for debugging and telemetry
    """

    violation_level: ViolationLevel

    # what message should you convey to the user
    user_message: str | None = None

    # additional metadata (including specific violation codes) more for
    # debugging, telemetry
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RunShieldResponse(BaseModel):
    """Response from running a safety shield.

    :param violation: (Optional) Safety violation detected by the shield, if any
    """

    violation: SafetyViolation | None = None
```
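For contrast, the existing `run_shield` path (slated for deprecation per the PR description) reports problems as a single `SafetyViolation` rather than per-category flags. A sketch of both outcomes, with an invented message and metadata:

```python
# Sketch of run_shield response shapes; user_message and metadata are invented.
# Import path assumed as in the previous example.
from llama_stack.apis.safety import RunShieldResponse, SafetyViolation, ViolationLevel

blocked = RunShieldResponse(
    violation=SafetyViolation(
        violation_level=ViolationLevel.ERROR,
        user_message="I can't help with that request.",
        metadata={"violation_type": "unsafe_content"},
    )
)
allowed = RunShieldResponse()  # violation stays None when nothing is flagged
```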
```python
class ShieldStore(Protocol):
    async def get_shield(self, identifier: str) -> Shield: ...


@runtime_checkable
@trace_protocol
class Safety(Protocol):
    shield_store: ShieldStore

    @webmethod(route="/safety/run-shield", method="POST")
    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any],
    ) -> RunShieldResponse:
        """Run a shield.

        :param shield_id: The identifier of the shield to run.
        :param messages: The messages to run the shield on.
        :param params: The parameters of the shield.
        :returns: A RunShieldResponse.
        """
        ...

    @webmethod(route="/openai/v1/moderations", method="POST")
    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        """Classifies if text and/or image inputs are potentially harmful.
        :param input: Input (or inputs) to classify.
            Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
        :param model: The content moderation model you would like to use.
        :returns: A moderation object.
        """
        ...
```
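Since the `ModerationObjectResults` docstring requires every category in `OpenAICategories` to be present, a provider implementation would typically default the full set to unflagged and then overwrite what it actually detected. A hedged sketch of such a helper follows; the function name, threshold, and text-only input type are assumptions, not part of the API:

```python
# Hypothetical provider-side helper: start from the full required category set,
# then mark the categories the safety model actually flagged. Text-only for
# now, matching the scope of this PR. Import path assumed as above.
from llama_stack.apis.safety import ModerationObjectResults, OpenAICategories


def build_results(detected_scores: dict[str, float], threshold: float = 0.5) -> ModerationObjectResults:
    categories = {cat.value: False for cat in OpenAICategories}
    scores = {cat.value: 0.0 for cat in OpenAICategories}
    applied: dict[str, list[str]] = {}
    for name, score in detected_scores.items():
        scores[name] = score
        if score >= threshold:
            categories[name] = True
            applied[name] = ["text"]
    return ModerationObjectResults(
        flagged=any(categories.values()),
        categories=categories,
        category_scores=scores,
        category_applied_input_types=applied,
    )
```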