Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)

Add validation for required OAI categories

commit 3b690b9ae6 (parent 215c8d91b8)
3 changed files with 55 additions and 17 deletions

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field

@@ -15,6 +15,27 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
 
+# OpenAI Categories to return in the response
+class OpenAICategories(StrEnum):
+    """
+    Required set of categories in moderations api response
+    """
+
+    VIOLENCE = "violence"
+    VIOLENCE_GRAPHIC = "violence/graphic"
+    HARRASMENT = "harassment"
+    HARRASMENT_THREATENING = "harassment/threatening"
+    HATE = "hate"
+    HATE_THREATENING = "hate/threatening"
+    ILLICIT = "illicit"
+    ILLICIT_VIOLENT = "illicit/violent"
+    SEXUAL = "sexual"
+    SEXUAL_MINORS = "sexual/minors"
+    SELF_HARM = "self-harm"
+    SELF_HARM_INTENT = "self-harm/intent"
+    SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
+
+
 @json_schema_type
 class ModerationObjectResults(BaseModel):
     """A moderation object.
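Because OpenAICategories subclasses StrEnum (available in Python 3.11+), each member is a real string: it compares equal to its value and str() returns that value, which is what the router below relies on when it builds the required-category list with list(map(str, OpenAICategories)). A minimal standalone sketch of that behavior, with a trimmed two-member stand-in for the enum:

    from enum import StrEnum

    class Categories(StrEnum):  # trimmed stand-in for OpenAICategories
        VIOLENCE = "violence"
        SEXUAL_MINORS = "sexual/minors"

    # StrEnum members behave like their underlying strings.
    assert Categories.VIOLENCE == "violence"
    assert str(Categories.SEXUAL_MINORS) == "sexual/minors"

    # Iterating the enum and mapping str() over it yields the canonical
    # category strings, as the router's validation does below.
    assert list(map(str, Categories)) == ["violence", "sexual/minors"]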

@@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
     Message,
 )
 from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.safety.safety import ModerationObject
+from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

@@ -78,7 +78,24 @@ class SafetyRouter(Safety):
         logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
-        return await provider.run_moderation(
+        response = await provider.run_moderation(
             input=input,
             model=model,
         )
+        self._validate_required_categories_exist(response)
+
+        return response
+
+    def _validate_required_categories_exist(self, response: ModerationObject) -> None:
+        """Validate the ProviderImpl response contains the required Open AI moderations categories."""
+        required_categories = list(map(str, OpenAICategories))
+
+        categories = response.results[0].categories
+        category_applied_input_types = response.results[0].category_applied_input_types
+        category_scores = response.results[0].category_scores
+
+        for i in [categories, category_applied_input_types, category_scores]:
+            if not set(required_categories).issubset(set(i.keys())):
+                raise ValueError(
+                    f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
+                )
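The new check only asserts key presence: every OpenAI category string must appear in each of the three per-result mappings (categories, category_applied_input_types, category_scores). A rough standalone sketch of the same logic, with plain dicts standing in for a real ModerationObject and a hypothetical validate_required_categories helper:

    from enum import StrEnum

    class Categories(StrEnum):  # trimmed stand-in for OpenAICategories
        VIOLENCE = "violence"
        HATE = "hate"

    def validate_required_categories(categories: dict, applied_input_types: dict, scores: dict) -> None:
        # Every required category must be a key in all three mappings,
        # mirroring SafetyRouter._validate_required_categories_exist.
        required = set(map(str, Categories))
        for mapping in (categories, applied_input_types, scores):
            missing = required - set(mapping.keys())
            if missing:
                raise ValueError(f"response is missing required categories: {missing}")

    # A response that covers both categories passes silently...
    validate_required_categories(
        {"violence": False, "hate": False},
        {"violence": [], "hate": []},
        {"violence": 0.01, "hate": 0.02},
    )

    # ...while one that drops "hate" raises ValueError.
    try:
        validate_required_categories({"violence": False}, {"violence": []}, {"violence": 0.01})
    except ValueError as exc:
        print(exc)  # response is missing required categories: {'hate'}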

@@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.models.llama.datatypes import Role

@@ -73,19 +73,19 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
 SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
 
 OPENAI_TO_LLAMA_CATEGORIES_MAP = {
-    "violence": [CAT_VIOLENT_CRIMES],
-    "violence/graphic": [CAT_VIOLENT_CRIMES],
-    "harassment": [CAT_CHILD_EXPLOITATION],
-    "harassment/threatening": [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
-    "hate": [CAT_HATE],
-    "hate/threatening": [CAT_HATE, CAT_VIOLENT_CRIMES],
-    "illicit": [CAT_NON_VIOLENT_CRIMES],
-    "illicit/violent": [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
-    "sexual": [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
-    "sexual/minors": [CAT_CHILD_EXPLOITATION],
-    "self-harm": [CAT_SELF_HARM],
-    "self-harm/intent": [CAT_SELF_HARM],
-    "self-harm/instructions": [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
+    OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
+    OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
+    OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
+    OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
+    OpenAICategories.HATE: [CAT_HATE],
+    OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
+    OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
+    OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
+    OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
+    OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
+    OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
+    OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
+    OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
     # These are custom categories that are not in the OpenAI moderation categories
     "custom/defamation": [CAT_DEFAMATION],
     "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
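Switching the map keys from bare strings to OpenAICategories members should be behavior-preserving: StrEnum members hash and compare like their string values, so a lookup by either the enum member or the plain string resolves to the same entry, while a misspelled key now fails at import time as an unknown enum attribute. A small sketch under that assumption, with an illustrative stand-in value for the Llama Guard constant:

    from enum import StrEnum

    class Categories(StrEnum):  # trimmed stand-in for OpenAICategories
        VIOLENCE = "violence"

    CAT_VIOLENT_CRIMES = "Violent Crimes"  # illustrative value, not the real constant

    OPENAI_TO_LLAMA_CATEGORIES_MAP = {
        Categories.VIOLENCE: [CAT_VIOLENT_CRIMES],
    }

    # Enum-member and plain-string lookups hit the same entry, because
    # StrEnum keys hash and compare equal to their underlying strings.
    assert OPENAI_TO_LLAMA_CATEGORIES_MAP[Categories.VIOLENCE] == [CAT_VIOLENT_CRIMES]
    assert OPENAI_TO_LLAMA_CATEGORIES_MAP["violence"] == [CAT_VIOLENT_CRIMES]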