feat: Add open ai compatible moderations api

This commit is contained in:
Swapna Lekkala 2025-08-01 15:54:00 -07:00
parent 0caef40e0d
commit c89fb40082
6 changed files with 549 additions and 0 deletions

View file

@ -10,6 +10,7 @@ from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -60,3 +61,24 @@ class SafetyRouter(Safety):
messages=messages,
params=params,
)
async def create(self, input: str | list[str], model: str) -> ModerationObject:
async def get_shield_id(self, model: str) -> str:
"""Get Shield id from model (provider_resource_id) of shield."""
list_shields_response = await self.routing_table.list_shields()
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1:
raise ValueError(f"Multiple shields associated with provider_resource id {model}")
return matches[0]
shield_id = await get_shield_id(self, model)
logger.debug(f"SafetyRouter.create: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.create(
input=input,
model=model,
)