update as per comments

This commit is contained in:
Swapna Lekkala 2025-08-05 11:29:37 -07:00
parent 2d608ddd3b
commit 13fb8126a1
3 changed files with 65 additions and 48 deletions

View file

@ -75,7 +75,7 @@ class SafetyRouter(Safety):
return matches[0]
shield_id = await get_shield_id(self, model)
logger.debug(f"SafetyRouter.create: {shield_id}")
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_moderation(

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
import uuid
from string import Template
@ -221,15 +222,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run(messages)
async def run_moderation(
self,
input: str | list[str],
model: str | None = None, # To replace with default model for llama-guard
) -> ModerationObject:
if model is None:
raise ValueError("Model cannot be None")
if not input or len(input) == 0:
raise ValueError("Input cannot be empty")
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
if isinstance(input, list):
messages = input.copy()
else:
@ -257,7 +250,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
safety_categories=safety_categories,
)
return await impl.run_create(messages)
return await impl.run_moderation(messages)
class LlamaGuardShield:
@ -406,11 +399,13 @@ class LlamaGuardShield:
raise ValueError(f"Unexpected response: {response}")
async def run_create(self, messages: list[Message]) -> ModerationObject:
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
messages=[shield_input_message],
@ -420,54 +415,75 @@ class LlamaGuardShield:
content = content.strip()
return self.get_moderation_object(content)
def create_safe_moderation_object(self, model: str) -> ModerationObject:
"""Create a ModerationObject for safe content."""
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
"""Create a ModerationObject for either safe or unsafe content.
Args:
model: The model name
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
Returns:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
flagged = False
user_message = None
metadata = {}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=False,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
# Handle unsafe case
if unsafe_code:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=user_message,
metadata=metadata,
)
],
)
],
)
def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject:
"""Create a ModerationObject for unsafe content."""
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
unsafe_code_list = unsafe_code.split(",")
openai_categories = []
for code in unsafe_code_list:
if code not in SAFETY_CODE_TO_CATEGORIES_MAP:
raise ValueError(f"Unknown code: {code}")
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
}
flagged = True
user_message = CANNED_RESPONSE_TEXT
metadata = {"violation_type": unsafe_code_list}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()
}
return ModerationObject(
id=f"modr-{uuid.uuid4()}",
model=model,
results=[
ModerationObjectResults(
flagged=True,
flagged=flagged,
categories=categories,
category_applied_input_types=category_applied_input_types,
category_scores=category_scores,
user_message=CANNED_RESPONSE_TEXT,
metadata={"violation_type": unsafe_code_list},
user_message=user_message,
metadata=metadata,
)
],
)
@ -487,12 +503,12 @@ class LlamaGuardShield:
def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip()
if self.is_content_safe(response):
return self.create_safe_moderation_object(self.model)
return self.create_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response)
if not unsafe_code:
raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code):
return self.create_safe_moderation_object(self.model)
return self.create_moderation_object(self.model)
else:
return self.create_unsafe_moderation_object(self.model, unsafe_code)
return self.create_moderation_object(self.model, unsafe_code)

View file

@ -65,6 +65,7 @@ def test_safe_examples_with_run_moderation(client_with_models, shield_id):
"How many years can you be a president in the US?",
"Quels sont les principaux bienfaits de l'alimentation méditerranéenne?",
"Search for 3 best places to see in San Francisco",
"",
]
shield = [shield for shield in client_with_models.shields.list() if shield.identifier == shield_id][0]
model_id = shield.provider_resource_id