From 3b690b9ae6a17eac28a1019b3c20042333163fa8 Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Wed, 6 Aug 2025 12:10:30 -0700
Subject: [PATCH] Add validation for required OAI categories

---
 llama_stack/apis/safety/safety.py                  | 23 ++++++++++++++-
 llama_stack/core/routers/safety.py                 | 21 ++++++++++++--
 .../inline/safety/llama_guard/llama_guard.py       | 28 +++++++++----------
 3 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index 87cefe8e4..3f374460b 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
@@ -15,6 +15,27 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
 
+# OpenAI categories to return in the response
+class OpenAICategories(StrEnum):
+    """
+    Required set of categories in the moderations API response
+    """
+
+    VIOLENCE = "violence"
+    VIOLENCE_GRAPHIC = "violence/graphic"
+    HARASSMENT = "harassment"
+    HARASSMENT_THREATENING = "harassment/threatening"
+    HATE = "hate"
+    HATE_THREATENING = "hate/threatening"
+    ILLICIT = "illicit"
+    ILLICIT_VIOLENT = "illicit/violent"
+    SEXUAL = "sexual"
+    SEXUAL_MINORS = "sexual/minors"
+    SELF_HARM = "self-harm"
+    SELF_HARM_INTENT = "self-harm/intent"
+    SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
+
+
 @json_schema_type
 class ModerationObjectResults(BaseModel):
     """A moderation object.
diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py
index b8bce274e..9bf2b1bac 100644
--- a/llama_stack/core/routers/safety.py
+++ b/llama_stack/core/routers/safety.py
@@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
     Message,
 )
 from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.safety.safety import ModerationObject
+from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
@@ -78,7 +78,24 @@ class SafetyRouter(Safety):
         logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
-        return await provider.run_moderation(
+        response = await provider.run_moderation(
             input=input,
             model=model,
         )
+        self._validate_required_categories_exist(response)
+
+        return response
+
+    def _validate_required_categories_exist(self, response: ModerationObject) -> None:
+        """Validate that the ProviderImpl response contains the required OpenAI moderation categories."""
+        required_categories = list(map(str, OpenAICategories))
+
+        categories = response.results[0].categories
+        category_applied_input_types = response.results[0].category_applied_input_types
+        category_scores = response.results[0].category_scores
+
+        for i in [categories, category_applied_input_types, category_scores]:
+            if not set(required_categories).issubset(set(i.keys())):
+                raise ValueError(
+                    f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
+                )
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 19a27a661..f83c39a6a 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.models.llama.datatypes import Role
@@ -73,19 +73,19 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
 SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
 
 OPENAI_TO_LLAMA_CATEGORIES_MAP = {
-    "violence": [CAT_VIOLENT_CRIMES],
-    "violence/graphic": [CAT_VIOLENT_CRIMES],
-    "harassment": [CAT_CHILD_EXPLOITATION],
-    "harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
-    "hate": [CAT_HATE],
-    "hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES],
-    "illicit": [CAT_NON_VIOLENT_CRIMES],
-    "illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
-    "sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
-    "sexual/minors": [CAT_CHILD_EXPLOITATION],
-    "self-harm": [CAT_SELF_HARM],
-    "self-harm/intent": [CAT_SELF_HARM],
-    "self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
+    OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
+    OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
+    OpenAICategories.HARASSMENT: [CAT_CHILD_EXPLOITATION],
+    OpenAICategories.HARASSMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
+    OpenAICategories.HATE: [CAT_HATE],
+    OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
+    OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
+    OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
+    OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
+    OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
+    OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
+    OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
+    OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
     # These are custom categories that are not in the OpenAI moderation categories
     "custom/defamation": [CAT_DEFAMATION],
     "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
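
For reviewers, a minimal, self-contained sketch of the check the router now performs after every run_moderation call. It mirrors _validate_required_categories_exist using plain dicts in place of the categories, category_applied_input_types, and category_scores fields of ModerationObjectResults; the helper name validate_required_categories_exist and the sample data are illustrative, not part of the patch (StrEnum requires Python 3.11+):

from enum import StrEnum


class OpenAICategories(StrEnum):
    # Mirrors the enum added in llama_stack/apis/safety/safety.py.
    VIOLENCE = "violence"
    VIOLENCE_GRAPHIC = "violence/graphic"
    HARASSMENT = "harassment"
    HARASSMENT_THREATENING = "harassment/threatening"
    HATE = "hate"
    HATE_THREATENING = "hate/threatening"
    ILLICIT = "illicit"
    ILLICIT_VIOLENT = "illicit/violent"
    SEXUAL = "sexual"
    SEXUAL_MINORS = "sexual/minors"
    SELF_HARM = "self-harm"
    SELF_HARM_INTENT = "self-harm/intent"
    SELF_HARM_INSTRUCTIONS = "self-harm/instructions"


def validate_required_categories_exist(result: dict[str, dict]) -> None:
    # Same logic as SafetyRouter._validate_required_categories_exist, applied to
    # plain dicts standing in for the three per-result mappings: every required
    # OpenAI category key must be present in each of the three fields.
    required = {str(c) for c in OpenAICategories}
    for field in ("categories", "category_applied_input_types", "category_scores"):
        missing = required - set(result[field].keys())
        if missing:
            raise ValueError(f"ProviderImpl response is missing required categories: {missing}")


# A response covering every required category in all three fields passes silently:
complete = {str(c): True for c in OpenAICategories}
validate_required_categories_exist(
    {
        "categories": complete,
        "category_applied_input_types": complete,
        "category_scores": complete,
    }
)

# Dropping any required key (here "sexual/minors") from one field raises ValueError:
partial = {k: v for k, v in complete.items() if k != "sexual/minors"}
try:
    validate_required_categories_exist(
        {
            "categories": partial,
            "category_applied_input_types": complete,
            "category_scores": complete,
        }
    )
except ValueError as e:
    print(e)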