mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
chore: Change moderations api response to Provider returned categories (#3098)
# What does this PR do? To be compliant with model policies for LLAMA, just return the categories as is from provider, we will lose the OAI compat in moderations api response. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan `SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`
This commit is contained in:
parent
a9081d87b9
commit
25e0553eed
6 changed files with 16 additions and 97 deletions
2
docs/_static/llama-stack-spec.html
vendored
2
docs/_static/llama-stack-spec.html
vendored
|
@ -16569,7 +16569,7 @@
|
|||
"additionalProperties": {
|
||||
"type": "number"
|
||||
},
|
||||
"description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions"
|
||||
"description": "A list of the categories along with their scores as predicted by model."
|
||||
},
|
||||
"user_message": {
|
||||
"type": "string"
|
||||
|
|
4
docs/_static/llama-stack-spec.yaml
vendored
4
docs/_static/llama-stack-spec.yaml
vendored
|
@ -12322,10 +12322,6 @@ components:
|
|||
type: number
|
||||
description: >-
|
||||
A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response - violence - violence/graphic
|
||||
- harassment - harassment/threatening - hate - hate/threatening - illicit
|
||||
- illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent
|
||||
- self-harm/instructions
|
||||
user_message:
|
||||
type: string
|
||||
metadata:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Categories to return in the response
|
||||
class OpenAICategories(StrEnum):
|
||||
"""
|
||||
Required set of categories in moderations api response
|
||||
"""
|
||||
|
||||
VIOLENCE = "violence"
|
||||
VIOLENCE_GRAPHIC = "violence/graphic"
|
||||
HARRASMENT = "harassment"
|
||||
HARRASMENT_THREATENING = "harassment/threatening"
|
||||
HATE = "hate"
|
||||
HATE_THREATENING = "hate/threatening"
|
||||
ILLICIT = "illicit"
|
||||
ILLICIT_VIOLENT = "illicit/violent"
|
||||
SEXUAL = "sexual"
|
||||
SEXUAL_MINORS = "sexual/minors"
|
||||
SELF_HARM = "self-harm"
|
||||
SELF_HARM_INTENT = "self-harm/intent"
|
||||
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
|
@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel):
|
|||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response
|
||||
- violence
|
||||
- violence/graphic
|
||||
- harassment
|
||||
- harassment/threatening
|
||||
- hate
|
||||
- hate/threatening
|
||||
- illicit
|
||||
- illicit/violent
|
||||
- sexual
|
||||
- sexual/minors
|
||||
- self-harm
|
||||
- self-harm/intent
|
||||
- self-harm/instructions
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
@ -82,20 +82,5 @@ class SafetyRouter(Safety):
|
|||
input=input,
|
||||
model=model,
|
||||
)
|
||||
self._validate_required_categories_exist(response)
|
||||
|
||||
return response
|
||||
|
||||
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
|
||||
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
|
||||
required_categories = list(map(str, OpenAICategories))
|
||||
|
||||
categories = response.results[0].categories
|
||||
category_applied_input_types = response.results[0].category_applied_input_types
|
||||
category_scores = response.results[0].category_scores
|
||||
|
||||
for i in [categories, category_applied_input_types, category_scores]:
|
||||
if not set(required_categories).issubset(set(i.keys())):
|
||||
raise ValueError(
|
||||
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
|
@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
|
||||
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HATE: [CAT_HATE],
|
||||
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
|
||||
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
|
||||
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
|
||||
# These are custom categories that are not in the OpenAI moderation categories
|
||||
"custom/defamation": [CAT_DEFAMATION],
|
||||
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
|
||||
"custom/privacy_violation": [CAT_PRIVACY],
|
||||
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
|
||||
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
|
||||
"custom/elections": [CAT_ELECTIONS],
|
||||
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
CAT_VIOLENT_CRIMES,
|
||||
|
@ -424,9 +400,9 @@ class LlamaGuardShield:
|
|||
ModerationObject with appropriate configuration
|
||||
"""
|
||||
# Set default values for safe case
|
||||
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
|
||||
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
flagged = False
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
@ -453,19 +429,15 @@ class LlamaGuardShield:
|
|||
],
|
||||
)
|
||||
|
||||
# Get OpenAI categories for the unsafe codes
|
||||
openai_categories = []
|
||||
for code in unsafe_code_list:
|
||||
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
|
||||
openai_categories.extend(
|
||||
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
|
||||
)
|
||||
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
|
||||
|
||||
# Update categories for unsafe content
|
||||
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
category_scores = {
|
||||
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
category_applied_input_types = {
|
||||
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
|
||||
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
flagged = True
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.safety import (
|
|||
ShieldStore,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
@ -64,8 +65,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str):
|
||||
raise NotImplementedError("run_moderation not implemented for PromptGuard")
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
||||
|
||||
|
||||
class PromptGuardShield:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue