llama-stack-mirror/llama_stack/core/routers/safety.py
slekkala1 25e0553eed
chore: Change moderations api response to Provider returned categories (#3098)
# What does this PR do?
To comply with the Llama model policies, return the categories exactly
as the provider reports them. As a consequence, the moderations API
response is no longer OpenAI-compatible.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
`SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest
-v tests/integration/safety/test_safety.py
--text-model=llama3.2:3b-instruct-fp16
--embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`
2025-08-13 09:47:35 -07:00

86 lines
3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
# Module-level logger for this router, created with the "core" category.
logger = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
    """Route Safety API calls through a routing table.

    Shield registration and removal are delegated to the routing table.
    Shield execution and moderation are forwarded to the provider
    implementation that the routing table resolves for the shield.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        """Keep the routing table used for registration and provider lookup."""
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """Lifecycle hook; the router itself only emits a debug log."""
        logger.debug("SafetyRouter.initialize")

    async def shutdown(self) -> None:
        """Lifecycle hook; the router itself only emits a debug log."""
        logger.debug("SafetyRouter.shutdown")

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        """Register a shield by delegating to the routing table.

        All arguments are passed through positionally, unchanged.

        Returns:
            Whatever the routing table's ``register_shield`` returns.
        """
        logger.debug("SafetyRouter.register_shield: %s", shield_id)
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def unregister_shield(self, identifier: str) -> None:
        """Unregister the shield ``identifier`` by delegating to the routing table."""
        logger.debug("SafetyRouter.unregister_shield: %s", identifier)
        return await self.routing_table.unregister_shield(identifier)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] | None = None,
    ) -> RunShieldResponse:
        """Run ``shield_id`` against ``messages`` on the provider that owns it.

        Args:
            shield_id: Identifier used to resolve the provider implementation.
            messages: Messages forwarded to the provider unchanged.
            params: Optional provider parameters; forwarded as-is, so the
                provider receives ``None`` when the caller omits them.

        Returns:
            The provider's ``run_shield`` response.
        """
        logger.debug("SafetyRouter.run_shield: %s", shield_id)
        provider = await self.routing_table.get_provider_impl(shield_id)
        return await provider.run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

    async def _get_shield_id(self, model: str) -> str:
        """Return the identifier of the single shield whose provider_resource_id is ``model``.

        Raises:
            ValueError: If no registered shield, or more than one, has
                ``model`` as its ``provider_resource_id``.
        """
        list_shields_response = await self.routing_table.list_shields()
        matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
        if not matches:
            raise ValueError(f"No shield associated with provider_resource id {model}")
        if len(matches) > 1:
            raise ValueError(f"Multiple shields associated with provider_resource id {model}")
        return matches[0]

    # NOTE: ``input`` shadows the builtin, but it is part of the public
    # signature (callers may pass it by keyword), so the name is kept.
    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        """Run moderation on ``input`` using the shield backed by ``model``.

        ``model`` is matched against each registered shield's
        ``provider_resource_id`` to find the shield, and therefore the
        provider, that should handle the request.

        Returns:
            The provider's ``run_moderation`` response, unmodified.

        Raises:
            ValueError: If ``model`` maps to no shield or to several shields.
        """
        shield_id = await self._get_shield_id(model)
        logger.debug("SafetyRouter.run_moderation: %s", shield_id)
        provider = await self.routing_table.get_provider_impl(shield_id)
        return await provider.run_moderation(
            input=input,
            model=model,
        )