feat: Add OpenAI-compatible moderations API

This commit is contained in:
Swapna Lekkala 2025-08-01 16:08:36 -07:00
parent bf6e411643
commit 0d07ba1a65

View file

@ -260,45 +260,6 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await impl.run_create(messages) return await impl.run_create(messages)
async def create(
    self,
    input: str | list[str],
    model: str | None = None,  # To replace with default model for llama-guard
) -> ModerationObject:
    """Run Llama Guard moderation over one or more input strings.

    Each input string is wrapped in a ``UserMessage`` and the batch is handed
    to a ``LlamaGuardShield`` constructed for ``model``.

    Args:
        input: A single string, or a list of strings, to moderate.
        model: Identifier of the model to run. No default model is wired in
            yet, so ``None`` is rejected.

    Returns:
        Whatever ``LlamaGuardShield.run_create`` returns for the wrapped
        messages.

    Raises:
        ValueError: If ``model`` is ``None`` or ``input`` is empty.
    """
    if model is None:
        raise ValueError("Model cannot be None")
    # An empty string and an empty list are both falsy; no separate len() check needed.
    if not input:
        raise ValueError("Input cannot be empty")

    # Normalize to a list; the comprehension below builds a new list, so the
    # caller's list is never mutated and no defensive copy is required.
    texts = input if isinstance(input, list) else [input]
    messages = [UserMessage(content=text) for text in texts]

    if model in LLAMA_GUARD_MODEL_IDS:
        # Known Llama Guard model: look up categories via the mapped name,
        # but keep passing the original ``model`` id to the shield.
        mapped_model = LLAMA_GUARD_MODEL_IDS[model]
        safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
    else:
        # Unknown model: fall back to the default v3 categories plus
        # code-interpreter abuse.
        safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

    impl = LlamaGuardShield(
        model=model,
        inference_api=self.inference_api,
        excluded_categories=self.config.excluded_categories,
        safety_categories=safety_categories,
    )
    return await impl.run_create(messages)
class LlamaGuardShield: class LlamaGuardShield:
def __init__( def __init__(
self, self,