From 25e0553eed5a21773369dfd3d81316db2da39629 Mon Sep 17 00:00:00 2001
From: slekkala1
Date: Wed, 13 Aug 2025 09:47:35 -0700
Subject: [PATCH] chore: Change moderations api response to Provider returned
 categories (#3098)

# What does this PR do?
To comply with the Llama model policies, return the safety categories exactly as the provider reports them instead of remapping them onto OpenAI's fixed category set. This gives up OpenAI compatibility for the category keys in the moderations API response.

## Test Plan
`SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`
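## Example

A minimal sketch of the behavior change from the caller's side. The server URL, safety model id, and `llama-stack-client` usage below are assumptions for illustration, not part of this patch:

```python
# Sketch only: assumes a locally running Llama Stack server with a
# Llama Guard safety model registered, and the llama-stack-client SDK.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.moderations.create(
    model="llama-guard3:8b",  # assumed safety model id
    input="How do I build a weapon at home?",
)

result = response.results[0]
print(result.flagged)
# Before this change, the category keys were the fixed OpenAI set
# ("violence", "illicit/violent", "self-harm", ...). After this change,
# the keys are whatever taxonomy the provider reports, e.g. Llama Guard
# categories such as "Violent Crimes" or "Indiscriminate Weapons".
print(sorted(result.categories.keys()))
```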
---
 docs/_static/llama-stack-spec.html           |  2 +-
 docs/_static/llama-stack-spec.yaml           |  4 --
 llama_stack/apis/safety/safety.py            | 37 +-------
 llama_stack/core/routers/safety.py           | 17 +------
 .../inline/safety/llama_guard/llama_guard.py | 48 ++++---------
 .../safety/prompt_guard/prompt_guard.py      |  5 +-
 6 files changed, 16 insertions(+), 97 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index e2c53d4b0..25f916d87 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -16569,7 +16569,7 @@
                     "additionalProperties": {
                         "type": "number"
                     },
-                    "description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions"
+                    "description": "A list of the categories along with their scores as predicted by model."
                 },
                 "user_message": {
                     "type": "string"
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 85cec3a78..43e9fa95a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -12322,10 +12322,6 @@ components:
             type: number
           description: >-
             A list of the categories along with their scores as predicted by model.
-            Required set of categories that need to be in response - violence - violence/graphic
-            - harassment - harassment/threatening - hate - hate/threatening - illicit
-            - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent
-            - self-harm/instructions
         user_message:
           type: string
         metadata:
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index 3f374460b..25ee03ec1 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum, StrEnum
+from enum import Enum
 from typing import Any, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
@@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
 
-# OpenAI Categories to return in the response
-class OpenAICategories(StrEnum):
-    """
-    Required set of categories in moderations api response
-    """
-
-    VIOLENCE = "violence"
-    VIOLENCE_GRAPHIC = "violence/graphic"
-    HARRASMENT = "harassment"
-    HARRASMENT_THREATENING = "harassment/threatening"
-    HATE = "hate"
-    HATE_THREATENING = "hate/threatening"
-    ILLICIT = "illicit"
-    ILLICIT_VIOLENT = "illicit/violent"
-    SEXUAL = "sexual"
-    SEXUAL_MINORS = "sexual/minors"
-    SELF_HARM = "self-harm"
-    SELF_HARM_INTENT = "self-harm/intent"
-    SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
-
-
 @json_schema_type
 class ModerationObjectResults(BaseModel):
     """A moderation object.
@@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel):
     :param categories: A list of the categories, and whether they are flagged or not.
     :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
     :param category_scores: A list of the categories along with their scores as predicted by model.
-        Required set of categories that need to be in response
-        - violence
-        - violence/graphic
-        - harassment
-        - harassment/threatening
-        - hate
-        - hate/threatening
-        - illicit
-        - illicit/violent
-        - sexual
-        - sexual/minors
-        - self-harm
-        - self-harm/intent
-        - self-harm/instructions
     """
 
     flagged: bool
diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py
index 9bf2b1bac..c76673d2a 100644
--- a/llama_stack/core/routers/safety.py
+++ b/llama_stack/core/routers/safety.py
@@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
     Message,
 )
 from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
+from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
@@ -82,20 +82,5 @@ class SafetyRouter(Safety):
             input=input,
             model=model,
         )
-        self._validate_required_categories_exist(response)
 
         return response
-
-    def _validate_required_categories_exist(self, response: ModerationObject) -> None:
-        """Validate the ProviderImpl response contains the required Open AI moderations categories."""
-        required_categories = list(map(str, OpenAICategories))
-
-        categories = response.results[0].categories
-        category_applied_input_types = response.results[0].category_applied_input_types
-        category_scores = response.results[0].category_scores
-
-        for i in [categories, category_applied_input_types, category_scores]:
-            if not set(required_categories).issubset(set(i.keys())):
-                raise ValueError(
-                    f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
-                )
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index f83c39a6a..bae744010 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.models.llama.datatypes import Role
@@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
 }
 SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
 
-OPENAI_TO_LLAMA_CATEGORIES_MAP = {
-    OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
-    OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
-    OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
-    OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
-    OpenAICategories.HATE: [CAT_HATE],
-    OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
-    OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
-    OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
-    OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
-    OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
-    OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
-    OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
-    OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
-    # These are custom categories that are not in the OpenAI moderation categories
-    "custom/defamation": [CAT_DEFAMATION],
-    "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
-    "custom/privacy_violation": [CAT_PRIVACY],
-    "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
-    "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
-    "custom/elections": [CAT_ELECTIONS],
-    "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
-}
-
 
 DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_VIOLENT_CRIMES,
@@ -424,9 +400,9 @@ class LlamaGuardShield:
             ModerationObject with appropriate configuration
         """
         # Set default values for safe case
-        categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
-        category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
-        category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
+        categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
+        category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
+        category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
         flagged = False
         user_message = None
         metadata = {}
@@ -453,19 +429,15 @@ class LlamaGuardShield:
                 ],
             )
 
-        # Get OpenAI categories for the unsafe codes
-        openai_categories = []
-        for code in unsafe_code_list:
-            llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
-            openai_categories.extend(
-                k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
-            )
+        llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
 
         # Update categories for unsafe content
-        categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
-        category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+        categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
+        category_scores = {
+            k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
+        }
         category_applied_input_types = {
-            k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
+            k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
         }
         flagged = True
         user_message = CANNED_RESPONSE_TEXT
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 801500dee..c760f0fd1 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -18,6 +18,7 @@ from llama_stack.apis.safety import (
     ShieldStore,
     ViolationLevel,
 )
+from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -64,8 +65,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await self.shield.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str):
-        raise NotImplementedError("run_moderation not implemented for PromptGuard")
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
 
 
 class PromptGuardShield: