mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
# What does this PR do? This PR adds an OpenAI-compatible moderations API. It is currently implemented only for the Llama Guard safety provider; image support, expansion to other safety providers, and deprecation of run_shield will be the next steps. ## Test Plan Added 2 new tests covering safe/unsafe text prompt examples for the new OpenAI-compatible moderations API: `SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama` (Had some issues with the previous PR https://github.com/meta-llama/llama-stack/pull/2994 while updating it and accidentally closed it, so a new one was opened.)
101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.inference import (
|
|
Message,
|
|
)
|
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
|
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
|
from llama_stack.apis.shields import Shield
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import RoutingTable
|
|
|
|
# Module-level logger for this file, created with the "core" category.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class SafetyRouter(Safety):
    """Route Safety API calls to the provider implementation behind a shield.

    Shield registration and removal are delegated to the routing table.
    ``run_shield`` and ``run_moderation`` look up the provider implementation
    for the relevant shield via the routing table and forward the call to it.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        """Store the routing table used to resolve shields to providers."""
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """Log the call; the router itself has nothing to set up."""
        logger.debug("SafetyRouter.initialize")

    async def shutdown(self) -> None:
        """Log the call; the router itself has nothing to tear down."""
        logger.debug("SafetyRouter.shutdown")

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        """Register a shield by delegating to the routing table.

        Returns:
            The ``Shield`` returned by the routing table.
        """
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def unregister_shield(self, identifier: str) -> None:
        """Unregister a shield by delegating to the routing table."""
        logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
        return await self.routing_table.unregister_shield(identifier)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] | None = None,
    ) -> RunShieldResponse:
        """Run the shield ``shield_id`` over ``messages`` on its provider.

        ``params`` is forwarded to the provider unchanged (including ``None``).

        Returns:
            The provider's ``RunShieldResponse``.
        """
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        provider = await self.routing_table.get_provider_impl(shield_id)
        return await provider.run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        """Run moderation on ``input`` using the shield backed by ``model``.

        ``model`` is matched against each registered shield's
        ``provider_resource_id`` to find the shield, whose provider then
        handles the request. The provider's response is validated before
        being returned.

        Raises:
            ValueError: If no shield, or more than one shield, matches
                ``model``, or if the provider response lacks a required
                moderation category.
        """
        shield_id = await self._get_shield_id(model)
        logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
        provider = await self.routing_table.get_provider_impl(shield_id)

        response = await provider.run_moderation(
            input=input,
            model=model,
        )
        self._validate_required_categories_exist(response)

        return response

    async def _get_shield_id(self, model: str) -> str:
        """Get Shield id from model (provider_resource_id) of shield.

        Raises:
            ValueError: If zero or multiple registered shields have
                ``provider_resource_id`` equal to ``model``.
        """
        list_shields_response = await self.routing_table.list_shields()

        matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
        if not matches:
            raise ValueError(f"No shield associated with provider_resource id {model}")
        if len(matches) > 1:
            raise ValueError(f"Multiple shields associated with provider_resource id {model}")
        return matches[0]

    def _validate_required_categories_exist(self, response: ModerationObject) -> None:
        """Validate the ProviderImpl response contains the required Open AI moderations categories.

        Every entry of ``response.results`` is checked (not only the first),
        and for each entry all three per-category mappings must contain every
        member of ``OpenAICategories``. A response with no results has
        nothing to validate and passes.

        Raises:
            ValueError: If any mapping of any result is missing a required
                category; the message lists the missing ones.
        """
        # Build the required set once; it is invariant across results.
        required_categories = {str(category) for category in OpenAICategories}

        for result in response.results:
            for mapping in (
                result.categories,
                result.category_applied_input_types,
                result.category_scores,
            ):
                missing = required_categories - set(mapping.keys())
                if missing:
                    raise ValueError(f"ProviderImpl response is missing required categories: {missing}")
|