chore: support default model in moderations API

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-10-22 14:52:00 -07:00
parent 7b90e0e9c8
commit 3381874de9
23 changed files with 211 additions and 36 deletions

View file

@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -20,9 +21,11 @@ class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
safety_config: SafetyConfig | None = None,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
self.safety_config = safety_config
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
@ -60,30 +63,47 @@ class SafetyRouter(Safety):
params=params,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
async def get_shield_id(self, model: str) -> str:
"""Get Shield id from model (provider_resource_id) of shield."""
list_shields_response = await self.routing_table.list_shields()
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
list_shields_response = await self.routing_table.list_shields()
shields = list_shields_response.data
matches: list[str] = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
selected_shield: Shield | None = None
provider_model: str | None = model
if model:
matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
if not matches:
raise ValueError(
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in list_shields_response.data]}"
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
)
if len(matches) > 1:
raise ValueError(
f"Multiple shields associated with provider_resource id {model}: matched shields {matches}"
f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
)
selected_shield = matches[0]
else:
default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
if not default_shield_id:
raise ValueError(
"No moderation model specified and no default_shield_id configured in safety config: select model "
f"from {[s.provider_resource_id or s.identifier for s in shields]}"
)
return matches[0]
shield_id = await get_shield_id(self, model)
selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
if selected_shield is None:
raise ValueError(
f"Configured default_shield_id '{default_shield_id}' not found. Available shields: {[s.identifier for s in shields]}"
)
provider_model = selected_shield.provider_resource_id or selected_shield.identifier
shield_id = selected_shield.identifier
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_moderation(
input=input,
model=model,
model=provider_model,
)
return response