Add validation for required OAI categories

This commit is contained in:
Swapna Lekkala 2025-08-06 12:10:30 -07:00
parent 215c8d91b8
commit 3b690b9ae6
3 changed files with 55 additions and 17 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -15,6 +15,27 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type @json_schema_type
class ModerationObjectResults(BaseModel): class ModerationObjectResults(BaseModel):
"""A moderation object. """A moderation object.

View file

@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
Message, Message,
) )
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
@ -78,7 +78,24 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.run_moderation: {shield_id}") logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id) provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_moderation( response = await provider.run_moderation(
input=input, input=input,
model=model, model=model,
) )
self._validate_required_categories_exist(response)
return response
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
required_categories = list(map(str, OpenAICategories))
categories = response.results[0].categories
category_applied_input_types = response.results[0].category_applied_input_types
category_scores = response.results[0].category_scores
for i in [categories, category_applied_input_types, category_scores]:
if not set(required_categories).issubset(set(i.keys())):
raise ValueError(
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
)

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation, SafetyViolation,
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role from llama_stack.models.llama.datatypes import Role
@ -73,19 +73,19 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = { OPENAI_TO_LLAMA_CATEGORIES_MAP = {
"violence": [CAT_VIOLENT_CRIMES], OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
"violence/graphic": [CAT_VIOLENT_CRIMES], OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
"harassment": [CAT_CHILD_EXPLOITATION], OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
"harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION], OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
"hate": [CAT_HATE], OpenAICategories.HATE: [CAT_HATE],
"hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES], OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
"illicit": [CAT_NON_VIOLENT_CRIMES], OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
"illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS], OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
"sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT], OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
"sexual/minors": [CAT_CHILD_EXPLOITATION], OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
"self-harm": [CAT_SELF_HARM], OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
"self-harm/intent": [CAT_SELF_HARM], OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
"self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE], OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories # These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION], "custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE], "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],