update as per comments

This commit is contained in:
Swapna Lekkala 2025-08-05 11:29:37 -07:00
parent 2d608ddd3b
commit 13fb8126a1
3 changed files with 65 additions and 48 deletions

View file

@ -75,7 +75,7 @@ class SafetyRouter(Safety):
return matches[0] return matches[0]
shield_id = await get_shield_id(self, model) shield_id = await get_shield_id(self, model)
logger.debug(f"SafetyRouter.create: {shield_id}") logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id) provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_moderation( return await provider.run_moderation(

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import re import re
import uuid import uuid
from string import Template from string import Template
@ -221,15 +222,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run(messages) return await impl.run(messages)
async def run_moderation( async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
self,
input: str | list[str],
model: str | None = None, # To replace with default model for llama-guard
) -> ModerationObject:
if model is None:
raise ValueError("Model cannot be None")
if not input or len(input) == 0:
raise ValueError("Input cannot be empty")
if isinstance(input, list): if isinstance(input, list):
messages = input.copy() messages = input.copy()
else: else:
@ -257,7 +250,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
safety_categories=safety_categories, safety_categories=safety_categories,
) )
return await impl.run_create(messages) return await impl.run_moderation(messages)
class LlamaGuardShield: class LlamaGuardShield:
@ -406,11 +399,13 @@ class LlamaGuardShield:
raise ValueError(f"Unexpected response: {response}") raise ValueError(f"Unexpected response: {response}")
async def run_create(self, messages: list[Message]) -> ModerationObject: async def run_moderation(self, messages: list[Message]) -> ModerationObject:
if not messages:
return self.create_moderation_object(self.model)
# TODO: Add Image based support for OpenAI Moderations # TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages) shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion( response = await self.inference_api.chat_completion(
model_id=self.model, model_id=self.model,
messages=[shield_input_message], messages=[shield_input_message],
@ -420,54 +415,75 @@ class LlamaGuardShield:
content = content.strip() content = content.strip()
return self.get_moderation_object(content) return self.get_moderation_object(content)
def create_safe_moderation_object(self, model: str) -> ModerationObject: def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
"""Create a ModerationObject for safe content.""" """Create a ModerationObject for either safe or unsafe content.
Args:
model: The model name
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
Returns:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False) categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0) category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
flagged = False
user_message = None
metadata = {}
# Handle unsafe case
if unsafe_code:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject( return ModerationObject(
id=f"modr-{uuid.uuid4()}", id=f"modr-{uuid.uuid4()}",
model=model, model=model,
results=[ results=[
ModerationObjectResults( ModerationObjectResults(
flagged=False, flagged=flagged,
categories=categories, categories=categories,
category_applied_input_types=category_applied_input_types, category_applied_input_types=category_applied_input_types,
category_scores=category_scores, category_scores=category_scores,
user_message=user_message,
metadata=metadata,
) )
], ],
) )
def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject: # Get OpenAI categories for the unsafe codes
"""Create a ModerationObject for unsafe content."""
unsafe_code_list = unsafe_code.split(",")
openai_categories = [] openai_categories = []
for code in unsafe_code_list: for code in unsafe_code_list:
if code not in SAFETY_CODE_TO_CATEGORIES_MAP: llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
raise ValueError(f"Unknown code: {code}")
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
openai_categories.extend( openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
) )
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} # Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_applied_input_types = { category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys() k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
} }
flagged = True
user_message = CANNED_RESPONSE_TEXT
metadata = {"violation_type": unsafe_code_list}
return ModerationObject( return ModerationObject(
id=f"modr-{uuid.uuid4()}", id=f"modr-{uuid.uuid4()}",
model=model, model=model,
results=[ results=[
ModerationObjectResults( ModerationObjectResults(
flagged=True, flagged=flagged,
categories=categories, categories=categories,
category_applied_input_types=category_applied_input_types, category_applied_input_types=category_applied_input_types,
category_scores=category_scores, category_scores=category_scores,
user_message=CANNED_RESPONSE_TEXT, user_message=user_message,
metadata={"violation_type": unsafe_code_list}, metadata=metadata,
) )
], ],
) )
@ -487,12 +503,12 @@ class LlamaGuardShield:
def get_moderation_object(self, response: str) -> ModerationObject: def get_moderation_object(self, response: str) -> ModerationObject:
response = response.strip() response = response.strip()
if self.is_content_safe(response): if self.is_content_safe(response):
return self.create_safe_moderation_object(self.model) return self.create_moderation_object(self.model)
unsafe_code = self.check_unsafe_response(response) unsafe_code = self.check_unsafe_response(response)
if not unsafe_code: if not unsafe_code:
raise ValueError(f"Unexpected response: {response}") raise ValueError(f"Unexpected response: {response}")
if self.is_content_safe(response, unsafe_code): if self.is_content_safe(response, unsafe_code):
return self.create_safe_moderation_object(self.model) return self.create_moderation_object(self.model)
else: else:
return self.create_unsafe_moderation_object(self.model, unsafe_code) return self.create_moderation_object(self.model, unsafe_code)

View file

@ -65,6 +65,7 @@ def test_safe_examples_with_run_moderation(client_with_models, shield_id):
"How many years can you be a president in the US?", "How many years can you be a president in the US?",
"Quels sont les principaux bienfaits de l'alimentation méditerranéenne?", "Quels sont les principaux bienfaits de l'alimentation méditerranéenne?",
"Search for 3 best places to see in San Francisco", "Search for 3 best places to see in San Francisco",
"",
] ]
shield = [shield for shield in client_with_models.shields.list() if shield.identifier == shield_id][0] shield = [shield for shield in client_with_models.shields.list() if shield.identifier == shield_id][0]
model_id = shield.provider_resource_id model_id = shield.provider_resource_id