mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
update as per comments
parent 2d608ddd3b, commit 13fb8126a1
3 changed files with 65 additions and 48 deletions
@@ -75,7 +75,7 @@ class SafetyRouter(Safety):
             return matches[0]

         shield_id = await get_shield_id(self, model)
-        logger.debug(f"SafetyRouter.create: {shield_id}")
+        logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)

         return await provider.run_moderation(
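
For orientation, the router method above looks up the shield whose provider_resource_id matches the requested model and delegates run_moderation to that provider. A minimal usage sketch, assuming an already configured SafetyRouter; the variable name router and the sample model id are illustrative, not part of this diff:

# Usage sketch (hypothetical names): run moderation through the SafetyRouter.
async def moderate_text(router, text: str) -> bool:
    result = await router.run_moderation(
        input=text,
        model="meta-llama/Llama-Guard-3-8B",  # assumed to match a registered shield's provider_resource_id
    )
    return result.results[0].flagged  # ModerationObject carries a list of per-input results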
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
 import re
 import uuid
 from string import Template
@@ -221,15 +222,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

         return await impl.run(messages)

-    async def run_moderation(
-        self,
-        input: str | list[str],
-        model: str | None = None,  # To replace with default model for llama-guard
-    ) -> ModerationObject:
-        if model is None:
-            raise ValueError("Model cannot be None")
-        if not input or len(input) == 0:
-            raise ValueError("Input cannot be empty")
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
         if isinstance(input, list):
             messages = input.copy()
         else:
@@ -257,7 +250,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             safety_categories=safety_categories,
         )

-        return await impl.run_create(messages)
+        return await impl.run_moderation(messages)


 class LlamaGuardShield:
@@ -406,11 +399,13 @@ class LlamaGuardShield:

         raise ValueError(f"Unexpected response: {response}")

-    async def run_create(self, messages: list[Message]) -> ModerationObject:
+    async def run_moderation(self, messages: list[Message]) -> ModerationObject:
+        if not messages:
+            return self.create_moderation_object(self.model)
         # TODO: Add Image based support for OpenAI Moderations
         shield_input_message = self.build_text_shield_input(messages)

         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         response = await self.inference_api.chat_completion(
             model_id=self.model,
             messages=[shield_input_message],
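
The guard clause added above short-circuits empty input: with nothing to evaluate, the shield returns the default unflagged moderation object instead of calling inference. A rough behavioral sketch, with the variable name shield assumed and the assertion shape taken from the fields used elsewhere in this diff:

# Assumed behavior: empty input yields a safe, unflagged ModerationObject.
obj = await shield.run_moderation([])   # shield: a LlamaGuardShield instance (illustrative)
assert obj.results[0].flagged is False  # categories default to False, scores to 1.0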
@@ -420,54 +415,75 @@
         content = content.strip()
         return self.get_moderation_object(content)

-    def create_safe_moderation_object(self, model: str) -> ModerationObject:
-        """Create a ModerationObject for safe content."""
+    def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
+        """Create a ModerationObject for either safe or unsafe content.
+
+        Args:
+            model: The model name
+            unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
+
+        Returns:
+            ModerationObject with appropriate configuration
+        """
+        # Set default values for safe case
         categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
         category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
         category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
-        return ModerationObject(
-            id=f"modr-{uuid.uuid4()}",
-            model=model,
-            results=[
-                ModerationObjectResults(
-                    flagged=False,
-                    categories=categories,
-                    category_applied_input_types=category_applied_input_types,
-                    category_scores=category_scores,
-                )
-            ],
-        )
-
-    def create_unsafe_moderation_object(self, model: str, unsafe_code: str) -> ModerationObject:
-        """Create a ModerationObject for unsafe content."""
-        unsafe_code_list = unsafe_code.split(",")
-        openai_categories = []
-        for code in unsafe_code_list:
-            if code not in SAFETY_CODE_TO_CATEGORIES_MAP:
-                raise ValueError(f"Unknown code: {code}")
-            llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP.get(code)
-            openai_categories.extend(
-                k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
-            )
-        categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
-        category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
-        category_applied_input_types = {
-            k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()
-        }
-        return ModerationObject(
-            id=f"modr-{uuid.uuid4()}",
-            model=model,
-            results=[
-                ModerationObjectResults(
-                    flagged=True,
-                    categories=categories,
-                    category_applied_input_types=category_applied_input_types,
-                    category_scores=category_scores,
-                    user_message=CANNED_RESPONSE_TEXT,
-                    metadata={"violation_type": unsafe_code_list},
-                )
-            ],
-        )
+        flagged = False
+        user_message = None
+        metadata = {}
+
+        # Handle unsafe case
+        if unsafe_code:
+            unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
+            invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
+            if invalid_codes:
+                logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+                # just returning safe object, as we don't know what the invalid codes can map to
+                return ModerationObject(
+                    id=f"modr-{uuid.uuid4()}",
+                    model=model,
+                    results=[
+                        ModerationObjectResults(
+                            flagged=flagged,
+                            categories=categories,
+                            category_applied_input_types=category_applied_input_types,
+                            category_scores=category_scores,
+                            user_message=user_message,
+                            metadata=metadata,
+                        )
+                    ],
+                )
+
+            # Get OpenAI categories for the unsafe codes
+            openai_categories = []
+            for code in unsafe_code_list:
+                llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
+                openai_categories.extend(
+                    k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
+                )

+            # Update categories for unsafe content
+            categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+            category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
+            category_applied_input_types = {
+                k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
+            }
+            flagged = True
+            user_message = CANNED_RESPONSE_TEXT
+            metadata = {"violation_type": unsafe_code_list}
+
+        return ModerationObject(
+            id=f"modr-{uuid.uuid4()}",
+            model=model,
+            results=[
+                ModerationObjectResults(
+                    flagged=flagged,
+                    categories=categories,
+                    category_applied_input_types=category_applied_input_types,
+                    category_scores=category_scores,
+                    user_message=user_message,
+                    metadata=metadata,
+                )
+            ],
+        )
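
The consolidated create_moderation_object starts from safe defaults and only overwrites them when a recognized unsafe code is supplied; the non-obvious step is translating each Llama Guard code into the OpenAI moderation categories that list it. A self-contained sketch of that translation with stand-in maps (the real SAFETY_CODE_TO_CATEGORIES_MAP and OPENAI_TO_LLAMA_CATEGORIES_MAP in llama_guard.py are larger; these trimmed versions are for illustration only):

# Stand-in maps for illustration; the real module defines the full tables.
SAFETY_CODE_TO_CATEGORIES_MAP = {"S1": "Violent Crimes", "S10": "Hate"}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
    "violence": ["Violent Crimes"],
    "hate": ["Hate"],
    "self-harm": ["Self-Harm"],
}

def map_unsafe_code(unsafe_code: str) -> dict[str, bool]:
    """Mirror the mapping step: flag every OpenAI category backed by a returned code."""
    codes = [code.strip() for code in unsafe_code.split(",")]
    openai_categories = []
    for code in codes:
        llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
        openai_categories.extend(
            k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
        )
    return {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}

print(map_unsafe_code("S1"))  # {'violence': True, 'hate': False, 'self-harm': False}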
@@ -487,12 +503,12 @@ class LlamaGuardShield:
     def get_moderation_object(self, response: str) -> ModerationObject:
         response = response.strip()
         if self.is_content_safe(response):
-            return self.create_safe_moderation_object(self.model)
+            return self.create_moderation_object(self.model)
         unsafe_code = self.check_unsafe_response(response)
         if not unsafe_code:
             raise ValueError(f"Unexpected response: {response}")

         if self.is_content_safe(response, unsafe_code):
-            return self.create_safe_moderation_object(self.model)
+            return self.create_moderation_object(self.model)
         else:
-            return self.create_unsafe_moderation_object(self.model, unsafe_code)
+            return self.create_moderation_object(self.model, unsafe_code)
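
After this change, both branches of get_moderation_object funnel into the single factory; the only difference is whether an unsafe code is passed along. A simplified, illustrative version of the safe/unsafe split, with the Llama Guard output format ("safe" vs. "unsafe\nS1") assumed rather than taken from this diff:

# Illustrative only: the real code uses is_content_safe() and check_unsafe_response().
def split_guard_response(response: str) -> str | None:
    response = response.strip()
    if response.lower().startswith("safe"):
        return None  # -> create_moderation_object(model)
    return response.split("\n", 1)[-1].strip()  # -> create_moderation_object(model, unsafe_code)

print(split_guard_response("safe"))         # None
print(split_guard_response("unsafe\nS1"))   # 'S1'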
@@ -65,6 +65,7 @@ def test_safe_examples_with_run_moderation(client_with_models, shield_id):
         "How many years can you be a president in the US?",
         "Quels sont les principaux bienfaits de l'alimentation méditerranéenne?",
         "Search for 3 best places to see in San Francisco",
+        "",
     ]
     shield = [shield for shield in client_with_models.shields.list() if shield.identifier == shield_id][0]
     model_id = shield.provider_resource_id
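
The test change adds the empty string to the safe-example list, which exercises the new empty-input guard end to end. A sketch of how such examples are typically driven through the moderation endpoint; the client.moderations.create call shape is an assumption (OpenAI-compatible surface) and `examples` stands for the list above:

# Hypothetical driver for the safe examples; fixtures come from the surrounding test.
for example in examples:
    moderation = client_with_models.moderations.create(input=example, model=model_id)
    assert moderation.results[0].flagged is False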