chore: support default model in moderations API

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-10-22 15:01:43 -07:00
parent 7b90e0e9c8
commit f6098fa73a
23 changed files with 212 additions and 36 deletions

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import ListShieldsResponse, Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.routers.safety import SafetyRouter
async def test_run_moderation_uses_default_shield_when_model_missing():
    """Check that moderation without a model goes through the default shield.

    The router is configured with ``default_shield_id="shield-1"`` and the
    only registered shield has no ``provider_resource_id``. Calling
    ``run_moderation`` with just an input is expected to look up the provider
    for ``"shield-1"`` and forward ``model="shield-1"`` together with the
    original input.
    """
    # Fixtures: the single registered shield and the canned provider reply.
    default_shield = Shield(
        identifier="shield-1",
        provider_resource_id=None,
        provider_id="provider-id",
        params={},
    )
    expected = ModerationObject(
        id="mid",
        model="shield-1",
        results=[ModerationObjectResults(flagged=False)],
    )

    # Provider stub that returns the canned moderation object.
    safety_provider = AsyncMock()
    safety_provider.run_moderation.return_value = expected

    # Routing table stub exposing the shield and resolving to the provider.
    table = AsyncMock()
    table.list_shields.return_value = ListShieldsResponse(data=[default_shield])
    table.get_provider_impl.return_value = safety_provider

    config = SafetyConfig(default_shield_id="shield-1")
    router = SafetyRouter(routing_table=table, safety_config=config)

    # No model argument is passed, so the default shield must be used.
    actual = await router.run_moderation("hello world")

    assert actual is expected
    table.get_provider_impl.assert_awaited_once_with("shield-1")
    safety_provider.run_moderation.assert_awaited_once()

    # Only the keyword arguments of interest are checked; extras are allowed.
    forwarded = safety_provider.run_moderation.call_args.kwargs
    assert forwarded["model"] == "shield-1"
    assert forwarded["input"] == "hello world"
async def test_run_moderation_prefers_provider_resource_id_when_available():
    """Check that the shield's provider resource id is sent as the model.

    The router is configured with ``default_shield_id="shield-2"`` and the
    registered shield carries ``provider_resource_id="provider/shield-model"``.
    Calling ``run_moderation`` with just an input is expected to look up the
    provider by the shield identifier (``"shield-2"``) while forwarding
    ``model="provider/shield-model"`` together with the original input.
    """
    # Fixtures: a shield with a provider-side id and the canned provider reply.
    mapped_shield = Shield(
        identifier="shield-2",
        provider_resource_id="provider/shield-model",
        provider_id="provider-id",
        params={},
    )
    expected = ModerationObject(
        id="mid",
        model="provider/shield-model",
        results=[ModerationObjectResults(flagged=False)],
    )

    # Provider stub that returns the canned moderation object.
    safety_provider = AsyncMock()
    safety_provider.run_moderation.return_value = expected

    # Routing table stub exposing the shield and resolving to the provider.
    table = AsyncMock()
    table.list_shields.return_value = ListShieldsResponse(data=[mapped_shield])
    table.get_provider_impl.return_value = safety_provider

    config = SafetyConfig(default_shield_id="shield-2")
    router = SafetyRouter(routing_table=table, safety_config=config)

    # No model argument is passed, so the default shield must be used.
    actual = await router.run_moderation("hello world")

    assert actual is expected
    # Provider lookup is keyed by the shield identifier, not the resource id.
    table.get_provider_impl.assert_awaited_once_with("shield-2")
    safety_provider.run_moderation.assert_awaited_once()

    # Only the keyword arguments of interest are checked; extras are allowed.
    forwarded = safety_provider.run_moderation.call_args.kwargs
    assert forwarded["model"] == "provider/shield-model"
    assert forwarded["input"] == "hello world"