Add validation for required OAI categories

This commit is contained in:
Swapna Lekkala 2025-08-06 12:10:30 -07:00
parent 215c8d91b8
commit 3b690b9ae6
3 changed files with 55 additions and 17 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, StrEnum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -15,6 +15,27 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.

View file

@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -78,7 +78,24 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_moderation(
response = await provider.run_moderation(
input=input,
model=model,
)
self._validate_required_categories_exist(response)
return response
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
required_categories = list(map(str, OpenAICategories))
categories = response.results[0].categories
category_applied_input_types = response.results[0].category_applied_input_types
category_scores = response.results[0].category_scores
for i in [categories, category_applied_input_types, category_scores]:
if not set(required_categories).issubset(set(i.keys())):
raise ValueError(
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
)

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -73,19 +73,19 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
"violence": [CAT_VIOLENT_CRIMES],
"violence/graphic": [CAT_VIOLENT_CRIMES],
"harassment": [CAT_CHILD_EXPLOITATION],
"harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
"hate": [CAT_HATE],
"hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES],
"illicit": [CAT_NON_VIOLENT_CRIMES],
"illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
"sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
"sexual/minors": [CAT_CHILD_EXPLOITATION],
"self-harm": [CAT_SELF_HARM],
"self-harm/intent": [CAT_SELF_HARM],
"self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],