Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-22 22:42:25 +00:00
feat: Add open ai compatible moderations api
Parent: 0caef40e0d
Commit: c89fb40082
6 changed files with 549 additions and 0 deletions
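For context, the feature targets the request/response shape of the OpenAI moderations endpoint, so a stock OpenAI client can be pointed at a Llama Stack deployment. The sketch below is illustrative only: the base_url, api_key, and model name are assumptions, not values taken from this commit.

    from openai import OpenAI

    # Point the standard OpenAI client at a Llama Stack server (URL/port are assumptions).
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    # "model" is resolved server-side to a registered shield via its provider_resource_id.
    resp = client.moderations.create(
        input="I want to say something hateful.",
        model="meta-llama/Llama-Guard-3-8B",  # illustrative shield model id
    )
    print(resp.results[0].flagged)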
@@ -10,6 +10,7 @@ from llama_stack.apis.inference import (
     Message,
 )
 from llama_stack.apis.safety import RunShieldResponse, Safety
+from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
@@ -60,3 +61,24 @@ class SafetyRouter(Safety):
             messages=messages,
             params=params,
         )
+
+    async def create(self, input: str | list[str], model: str) -> ModerationObject:
+        async def get_shield_id(self, model: str) -> str:
+            """Get Shield id from model (provider_resource_id) of shield."""
+            list_shields_response = await self.routing_table.list_shields()
+
+            matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
+            if not matches:
+                raise ValueError(f"No shield associated with provider_resource id {model}")
+            if len(matches) > 1:
+                raise ValueError(f"Multiple shields associated with provider_resource id {model}")
+            return matches[0]
+
+        shield_id = await get_shield_id(self, model)
+        logger.debug(f"SafetyRouter.create: {shield_id}")
+        provider = await self.routing_table.get_provider_impl(shield_id)
+
+        return await provider.create(
+            input=input,
+            model=model,
+        )
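The added create() resolves the OpenAI-style "model" argument to a registered shield by matching it against each shield's provider_resource_id, then dispatches to that shield's provider implementation. Below is a minimal, self-contained sketch of that lookup rule; the ShieldStub class and the sample identifiers are hypothetical stand-ins for the router's real shield records, not part of the commit.

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class ShieldStub:
        identifier: str            # shield id used for routing
        provider_resource_id: str  # model name the shield wraps

    async def get_shield_id(shields: list[ShieldStub], model: str) -> str:
        # Same matching rule as the diff: exactly one shield must match the requested model.
        matches = [s.identifier for s in shields if model == s.provider_resource_id]
        if not matches:
            raise ValueError(f"No shield associated with provider_resource id {model}")
        if len(matches) > 1:
            raise ValueError(f"Multiple shields associated with provider_resource id {model}")
        return matches[0]

    async def main() -> None:
        shields = [ShieldStub("llama-guard", "meta-llama/Llama-Guard-3-8B")]
        print(await get_shield_id(shields, "meta-llama/Llama-Guard-3-8B"))  # -> llama-guard

    asyncio.run(main())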