From 13fb8126a12b9d3a2f05d8f489674e22da291d1a Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Tue, 5 Aug 2025 11:29:37 -0700
Subject: [PATCH] update as per comments

---
 llama_stack/core/routers/safety.py            |   2 +-
 .../inline/safety/llama_guard/llama_guard.py  | 110 ++++++++++--------
 tests/integration/safety/test_safety.py       |   1 +
 3 files changed, 65 insertions(+), 48 deletions(-)

diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py
index c76e3f175..b8bce274e 100644
--- a/llama_stack/core/routers/safety.py
+++ b/llama_stack/core/routers/safety.py
@@ -75,7 +75,7 @@ class SafetyRouter(Safety):
             return matches[0]
 
         shield_id = await get_shield_id(self, model)
-        logger.debug(f"SafetyRouter.create: {shield_id}")
+        logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
         return await provider.run_moderation(
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 713e5fa00..8aac443be 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
 import re
 import uuid
 from string import Template
@@ -221,15 +222,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await impl.run(messages)
 
-    async def run_moderation(
-        self,
-        input: str | list[str],
-        model: str | None = None,  # To replace with default model for llama-guard
-    ) -> ModerationObject:
-        if model is None:
-            raise ValueError("Model cannot be None")
-        if not input or len(input) == 0:
-            raise ValueError("Input cannot be empty")
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
         if isinstance(input, list):
             messages = input.copy()
         else:
@@ -257,7 +250,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             safety_categories=safety_categories,
         )
 
-        return await impl.run_create(messages)
+        return await impl.run_moderation(messages)
 
 
 class LlamaGuardShield:
@@ -406,11 +399,13 @@ class LlamaGuardShield:
 
         raise ValueError(f"Unexpected response: {response}")
 
-    async def run_create(self, messages: list[Message]) -> ModerationObject:
+    async def run_moderation(self, messages: list[Message]) -> ModerationObject:
+        if not messages:
+            return self.create_moderation_object(self.model)
+
         # TODO: Add Image based support for OpenAI Moderations
         shield_input_message = self.build_text_shield_input(messages)
 
-        # TODO: llama-stack inference protocol has issues with non-streaming inference code
         response = await self.inference_api.chat_completion(
             model_id=self.model,
             messages=[shield_input_message],
@@ -420,54 +415,75 @@ class LlamaGuardShield:
         content = content.strip()
         return self.get_moderation_object(content)
 
-    def create_safe_moderation_object(self, model: str) -> ModerationObject:
-        """Create a ModerationObject for safe content."""
+    def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
+        """Create a ModerationObject for either safe or unsafe content.
+
+        Args:
+            model: The model name
+            unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
+
+        Returns:
+            ModerationObject with appropriate configuration
+        """
+        # Set default values for safe case
         categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
         category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
         category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
+        flagged = False
+        user_message = None
+        metadata = {}
 
-        return ModerationObject(
-            id=f"modr-{uuid.uuid4()}",
-            model=model,
-            results=[
-                ModerationObjectResults(
-                    flagged=False,
-                    categories=categories,
-                    category_applied_input_types=category_applied_input_types,
-                    category_scores=category_scores,
+        # Handle unsafe case
+        if unsafe_code:
+            unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
+            invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
+            if invalid_codes:
+                logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+                # just returning safe object, as we don't know what the invalid codes can map to
+                return ModerationObject(
+                    id=f"modr-{uuid.uuid4()}",
+                    model=model,
+                    results=[
+                        ModerationObjectResults(
+                            flagged=flagged,
+                            categories=categories,
+                            category_applied_input_types=category_applied_input_types,
+                            category_scores=category_scores,
+                            user_message=user_message,
+                            metadata=metadata,
+                        )
+                    ],
                 )
-            ],
-        )
 
-    def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject:
-        """Create a ModerationObject for unsafe content."""
+            # Get OpenAI categories for the unsafe codes
+            openai_categories = []
+            for code in unsafe_code_list:
+                llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
+                openai_categories.extend(
+                    k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
+                )
 
-        unsafe_code_list = unsafe_code.split(",")
-        openai_categories = []
-        for code in unsafe_code_list:
-            if code not in SAFETY_CODE_TO_CATEGORIES_MAP:
-                raise ValueError(f"Unknown code: {code}")
-            llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
-            openai_categories.extend(
-                k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
-            )
-        categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+            # Update categories for unsafe content
+            categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+            category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+            category_applied_input_types = {
+                k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
+            }
+            flagged = True
+            user_message = CANNED_RESPONSE_TEXT
+            metadata = {"violation_type": unsafe_code_list}
 
-        category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
-        category_applied_input_types = {
-            k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()
-        }
         return ModerationObject(
             id=f"modr-{uuid.uuid4()}",
             model=model,
             results=[
                 ModerationObjectResults(
-                    flagged=True,
+                    flagged=flagged,
                     categories=categories,
                     category_applied_input_types=category_applied_input_types,
                     category_scores=category_scores,
-                    user_message=CANNED_RESPONSE_TEXT,
-                    metadata={"violation_type": unsafe_code_list},
+                    user_message=user_message,
+                    metadata=metadata,
                 )
             ],
         )
@@ -487,12 +503,12 @@ class LlamaGuardShield:
     def get_moderation_object(self, response: str) -> ModerationObject:
         response = response.strip()
         if self.is_content_safe(response):
-            return self.create_safe_moderation_object(self.model)
+            return self.create_moderation_object(self.model)
 
         unsafe_code = self.check_unsafe_response(response)
         if not unsafe_code:
             raise ValueError(f"Unexpected response: {response}")
         if self.is_content_safe(response, unsafe_code):
-            return self.create_safe_moderation_object(self.model)
+            return self.create_moderation_object(self.model)
         else:
-            return self.create_unsafe_moderation_object(self.model, unsafe_code)
+            return self.create_moderation_object(self.model, unsafe_code)
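
The unified create_moderation_object above folds the old safe/unsafe builders into one code path: safe defaults are set first and only overwritten when every comma-separated Llama Guard code resolves through SAFETY_CODE_TO_CATEGORIES_MAP into OpenAI moderation categories. A minimal standalone sketch of that mapping, using toy stand-ins for the two category maps (the real constants live in llama_guard.py and are not reproduced here):

# Illustrative sketch only; the maps below are hypothetical stand-ins for
# SAFETY_CODE_TO_CATEGORIES_MAP and OPENAI_TO_LLAMA_CATEGORIES_MAP.
SAFETY_CODE_TO_CATEGORIES = {"S1": "Violent Crimes", "S10": "Hate"}
OPENAI_TO_LLAMA_CATEGORIES = {
    "violence": ["Violent Crimes"],
    "hate": ["Hate"],
    "self-harm": ["Suicide & Self-Harm"],
}

def moderation_fields(unsafe_code: str | None = None) -> dict:
    # Safe defaults, mirroring the patch: nothing flagged, no applied input types.
    categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES, False)
    scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES, 1.0)
    applied = {k: [] for k in OPENAI_TO_LLAMA_CATEGORIES}
    flagged = False

    if unsafe_code:
        codes = [c.strip() for c in unsafe_code.split(",")]
        if all(c in SAFETY_CODE_TO_CATEGORIES for c in codes):
            # Valid codes: flag the OpenAI categories whose Llama Guard
            # category list intersects the returned codes.
            llama_cats = {SAFETY_CODE_TO_CATEGORIES[c] for c in codes}
            openai_cats = [k for k, v in OPENAI_TO_LLAMA_CATEGORIES.items() if llama_cats & set(v)]
            categories = {k: k in openai_cats for k in OPENAI_TO_LLAMA_CATEGORIES}
            scores = {k: 1.0 if k in openai_cats else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES}
            applied = {k: ["text"] if k in openai_cats else [] for k in OPENAI_TO_LLAMA_CATEGORIES}
            flagged = True
        # Unknown codes: fall through to the safe defaults, as the patch does.
    return {"flagged": flagged, "categories": categories,
            "category_scores": scores, "category_applied_input_types": applied}

print(moderation_fields("S1,S10"))  # flags "violence" and "hate"
print(moderation_fields())          # safe defaults, nothing flagged
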
diff --git a/tests/integration/safety/test_safety.py b/tests/integration/safety/test_safety.py
index 260ef0016..11fd828bc 100644
--- a/tests/integration/safety/test_safety.py
+++ b/tests/integration/safety/test_safety.py
@@ -65,6 +65,7 @@ def test_safe_examples_with_run_moderation(client_with_models, shield_id):
         "How many years can you be a president in the US?",
         "Quels sont les principaux bienfaits de l'alimentation méditerranéenne?",
         "Search for 3 best places to see in San Francisco",
+        "",
     ]
     shield = [shield for shield in client_with_models.shields.list() if shield.identifier == shield_id][0]
     model_id = shield.provider_resource_id
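
The test change above adds the empty string to the safe examples: with the old input validation removed from run_moderation, an empty input is moderated like any other and is expected to come back unflagged. A hypothetical end-to-end call, assuming a locally running stack and the OpenAI-style moderations.create client method this test file appears to exercise (base_url and port are placeholders, not part of the patch):

# Hypothetical usage sketch; client setup and endpoint shape are assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server
shield = client.shields.list()[0]
model_id = shield.provider_resource_id

# With this patch an empty string no longer raises "Input cannot be empty";
# the safe-examples test now expects it to come back unflagged.
response = client.moderations.create(input="", model=model_id)
assert response.results[0].flagged is False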