Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-26 09:15:40 +00:00
chore: support default model in moderations API (#3890)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 3s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Test External API and Providers / test-external (venv) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 12s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
UI Tests / ui-tests (22) (push) Successful in 41s
Pre-commit / pre-commit (push) Successful in 1m33s
# What does this PR do?

https://platform.openai.com/docs/api-reference/moderations supports an optional `model` parameter. This PR adds support for calling the moderations API with `model=None` when a default shield id is provided via the safety config.

## Test Plan

Added tests.

Manual test:

```
> SAFETY_MODEL='together/meta-llama/Llama-Guard-4-12B' uv run llama stack run starter
> curl http://localhost:8321/v1/moderations \
  -H "Content-Type: application/json" \
  -d '{ "input": [ "hello" ] }'
```
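The same flow can be exercised from Python. A minimal sketch, assuming a Llama Stack server running on localhost:8321 (started as in the test plan) with `default_shield_id` set in its safety config, and using the standard `openai` client against the OpenAI-compatible `/v1` endpoints; the explicit model id is the Llama Guard shield from the test plan and should be adjusted to whatever shields are registered:

```python
# Sketch only: assumes a running server with a default shield configured.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Explicit model: routed to the shield whose provider_resource_id matches.
explicit = client.moderations.create(
    model="together/meta-llama/Llama-Guard-4-12B",
    input=["hello"],
)

# No model given: the router falls back to the configured default shield.
implicit = client.moderations.create(input=["hello"])

print(explicit.results[0].flagged, implicit.results[0].flagged)
```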
Parent: d12e5f0999
Commit: 9916cb3b17
23 changed files with 189 additions and 36 deletions
```diff
@@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
+from llama_stack.core.datatypes import SafetyConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
@@ -20,9 +21,11 @@ class SafetyRouter(Safety):
     def __init__(
         self,
         routing_table: RoutingTable,
+        safety_config: SafetyConfig | None = None,
     ) -> None:
         logger.debug("Initializing SafetyRouter")
         self.routing_table = routing_table
+        self.safety_config = safety_config
 
     async def initialize(self) -> None:
         logger.debug("SafetyRouter.initialize")
@@ -60,30 +63,47 @@ class SafetyRouter(Safety):
             params=params,
         )
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
-        async def get_shield_id(self, model: str) -> str:
-            """Get Shield id from model (provider_resource_id) of shield."""
-            list_shields_response = await self.routing_table.list_shields()
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        list_shields_response = await self.routing_table.list_shields()
+        shields = list_shields_response.data
 
-            matches: list[str] = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
+        selected_shield: Shield | None = None
+        provider_model: str | None = model
+
+        if model:
+            matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
             if not matches:
                 raise ValueError(
-                    f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in list_shields_response.data]}"
+                    f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
                 )
             if len(matches) > 1:
                 raise ValueError(
-                    f"Multiple shields associated with provider_resource id {model}: matched shields {matches}"
+                    f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
+                )
+            selected_shield = matches[0]
+        else:
+            default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
+            if not default_shield_id:
+                raise ValueError(
+                    "No moderation model specified and no default_shield_id configured in safety config: select model "
+                    f"from {[s.provider_resource_id or s.identifier for s in shields]}"
                 )
-            return matches[0]
-
-        shield_id = await get_shield_id(self, model)
+            selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
+            if selected_shield is None:
+                raise ValueError(
+                    f"Default moderation model not found. Choose from {[s.provider_resource_id or s.identifier for s in shields]}."
+                )
+
+            provider_model = selected_shield.provider_resource_id
+
+        shield_id = selected_shield.identifier
         logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
         response = await provider.run_moderation(
             input=input,
-            model=model,
+            model=provider_model,
         )
 
         return response
 
```
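For readers skimming the diff, the selection rule reduces to: an explicit `model` must match exactly one registered shield's `provider_resource_id`; with no model, the shield whose `identifier` equals the configured `default_shield_id` is used, and its `provider_resource_id` becomes the model passed to the provider. A standalone sketch of that rule, using a simplified stand-in `Shield` type rather than the actual `llama_stack` classes:

```python
# Illustration of the new selection rule only; Shield here is a simplified
# stand-in dataclass, not llama_stack.apis.shields.Shield.
from dataclasses import dataclass


@dataclass
class Shield:
    identifier: str
    provider_resource_id: str | None


def select_shield(
    shields: list[Shield], model: str | None, default_shield_id: str | None
) -> tuple[str, str | None]:
    """Return the (shield_id, provider_model) pair the router would use."""
    if model:
        matches = [s for s in shields if s.provider_resource_id == model]
        if len(matches) != 1:
            raise ValueError(f"Expected exactly one shield for {model}, found {len(matches)}")
        return matches[0].identifier, model
    if not default_shield_id:
        raise ValueError("No model given and no default_shield_id configured")
    selected = next((s for s in shields if s.identifier == default_shield_id), None)
    if selected is None:
        raise ValueError(f"Default shield {default_shield_id} is not registered")
    return selected.identifier, selected.provider_resource_id


shields = [Shield("llama-guard", "together/meta-llama/Llama-Guard-4-12B")]
print(select_shield(shields, "together/meta-llama/Llama-Guard-4-12B", None))  # explicit model
print(select_shield(shields, None, "llama-guard"))                            # default shield fallback
```

Note that in the fallback branch the provider still receives a concrete model id (the shield's `provider_resource_id`), which is why the router rewrites `model` to `provider_model` before calling `provider.run_moderation`.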