mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
feat: Add moderations create api (#3020)
# What does this PR do? This PR adds Open AI Compatible moderations api. Currently only implementing for llama guard safety provider Image support, expand to other safety providers and Deprecation of run_shield will be next steps. ## Test Plan Added 2 new tests for safe/ unsafe text prompt examples for the new open ai compatible moderations api usage `SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama` (Had some issue with previous PR https://github.com/meta-llama/llama-stack/pull/2994 while updating and accidentally close it , reopened new one )
This commit is contained in:
parent
0caef40e0d
commit
26d3d25c87
6 changed files with 622 additions and 1 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -15,6 +15,71 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Categories to return in the response
|
||||
class OpenAICategories(StrEnum):
|
||||
"""
|
||||
Required set of categories in moderations api response
|
||||
"""
|
||||
|
||||
VIOLENCE = "violence"
|
||||
VIOLENCE_GRAPHIC = "violence/graphic"
|
||||
HARRASMENT = "harassment"
|
||||
HARRASMENT_THREATENING = "harassment/threatening"
|
||||
HATE = "hate"
|
||||
HATE_THREATENING = "hate/threatening"
|
||||
ILLICIT = "illicit"
|
||||
ILLICIT_VIOLENT = "illicit/violent"
|
||||
SEXUAL = "sexual"
|
||||
SEXUAL_MINORS = "sexual/minors"
|
||||
SELF_HARM = "self-harm"
|
||||
SELF_HARM_INTENT = "self-harm/intent"
|
||||
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
:param flagged: Whether any of the below categories are flagged.
|
||||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response
|
||||
- violence
|
||||
- violence/graphic
|
||||
- harassment
|
||||
- harassment/threatening
|
||||
- hate
|
||||
- hate/threatening
|
||||
- illicit
|
||||
- illicit/violent
|
||||
- sexual
|
||||
- sexual/minors
|
||||
- self-harm
|
||||
- self-harm/intent
|
||||
- self-harm/instructions
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
categories: dict[str, bool] | None = None
|
||||
category_applied_input_types: dict[str, list[str]] | None = None
|
||||
category_scores: dict[str, float] | None = None
|
||||
user_message: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObject(BaseModel):
|
||||
"""A moderation object.
|
||||
:param id: The unique identifier for the moderation request.
|
||||
:param model: The model used to generate the moderation results.
|
||||
:param results: A list of moderation objects
|
||||
"""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
results: list[ModerationObjectResults]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
"""Severity level of a safety violation.
|
||||
|
@ -82,3 +147,13 @@ class Safety(Protocol):
|
|||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST")
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
:returns: A moderation object.
|
||||
"""
|
||||
...
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue