mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
feat: Add open ai compatible moderations api
This commit is contained in:
parent
bf6e411643
commit
0d07ba1a65
1 changed files with 0 additions and 39 deletions
|
@ -260,45 +260,6 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
return await impl.run_create(messages)
|
return await impl.run_create(messages)
|
||||||
|
|
||||||
|
|
||||||
async def create(
|
|
||||||
self,
|
|
||||||
input: str | list[str],
|
|
||||||
model: str | None = None, # To replace with default model for llama-guard
|
|
||||||
) -> ModerationObject:
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Model cannot be None")
|
|
||||||
if not input or len(input) == 0:
|
|
||||||
raise ValueError("Input cannot be empty")
|
|
||||||
if isinstance(input, list):
|
|
||||||
messages = input.copy()
|
|
||||||
else:
|
|
||||||
messages = [input]
|
|
||||||
|
|
||||||
# convert to user messages format with role
|
|
||||||
messages = [UserMessage(content=m) for m in messages]
|
|
||||||
|
|
||||||
# Use the inference API's model resolution instead of hardcoded mappings
|
|
||||||
# This allows the shield to work with any registered model
|
|
||||||
# Determine safety categories based on the model type
|
|
||||||
# For known Llama Guard models, use specific categories
|
|
||||||
if model in LLAMA_GUARD_MODEL_IDS:
|
|
||||||
# Use the mapped model for categories but the original model_id for inference
|
|
||||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
|
|
||||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
|
||||||
else:
|
|
||||||
# For unknown models, use default Llama Guard 3 8B categories
|
|
||||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
|
||||||
|
|
||||||
impl = LlamaGuardShield(
|
|
||||||
model=model,
|
|
||||||
inference_api=self.inference_api,
|
|
||||||
excluded_categories=self.config.excluded_categories,
|
|
||||||
safety_categories=safety_categories,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await impl.run_create(messages)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaGuardShield:
|
class LlamaGuardShield:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue