mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-28 09:45:34 +00:00
chore: support default model in moderations API (#3890)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 3s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Test External API and Providers / test-external (venv) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 12s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
UI Tests / ui-tests (22) (push) Successful in 41s
Pre-commit / pre-commit (push) Successful in 1m33s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 3s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Test External API and Providers / test-external (venv) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 12s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
UI Tests / ui-tests (22) (push) Successful in 41s
Pre-commit / pre-commit (push) Successful in 1m33s
# What does this PR do? https://platform.openai.com/docs/api-reference/moderations supports optional model parameter. This PR adds support for using moderations API with model=None if a default shield id is provided via safety config. ## Test Plan added tests manual test: ``` > SAFETY_MODEL='together/meta-llama/Llama-Guard-4-12B' uv run llama stack run starter > curl http://localhost:8321/v1/moderations \ -H "Content-Type: application/json" \ -d '{ "input": [ "hello" ] }' ```
This commit is contained in:
parent
d12e5f0999
commit
9916cb3b17
23 changed files with 189 additions and 36 deletions
43
tests/unit/core/routers/test_safety_router.py
Normal file
43
tests/unit/core/routers/test_safety_router.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.core.routers.safety import SafetyRouter
|
||||
|
||||
|
||||
async def test_run_moderation_uses_default_shield_when_model_missing():
|
||||
routing_table = AsyncMock()
|
||||
shield = Shield(
|
||||
identifier="shield-1",
|
||||
provider_resource_id="provider/shield-model",
|
||||
provider_id="provider-id",
|
||||
params={},
|
||||
)
|
||||
routing_table.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||
|
||||
moderation_response = ModerationObject(
|
||||
id="mid",
|
||||
model="shield-1",
|
||||
results=[ModerationObjectResults(flagged=False)],
|
||||
)
|
||||
provider = AsyncMock()
|
||||
provider.run_moderation.return_value = moderation_response
|
||||
routing_table.get_provider_impl.return_value = provider
|
||||
|
||||
router = SafetyRouter(routing_table=routing_table, safety_config=SafetyConfig(default_shield_id="shield-1"))
|
||||
|
||||
result = await router.run_moderation("hello world")
|
||||
|
||||
assert result is moderation_response
|
||||
routing_table.get_provider_impl.assert_awaited_once_with("shield-1")
|
||||
provider.run_moderation.assert_awaited_once()
|
||||
_, kwargs = provider.run_moderation.call_args
|
||||
assert kwargs["model"] == "provider/shield-model"
|
||||
assert kwargs["input"] == "hello world"
|
||||
|
|
@ -11,8 +11,9 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
||||
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_vector_stores_config
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
|
|
@ -65,3 +66,37 @@ class TestVectorStoresValidation:
|
|||
)
|
||||
|
||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||
|
||||
|
||||
class TestSafetyConfigValidation:
|
||||
async def test_validate_success(self):
|
||||
safety_config = SafetyConfig(default_shield_id="shield-1")
|
||||
|
||||
shield = Shield(
|
||||
identifier="shield-1",
|
||||
provider_id="provider-x",
|
||||
provider_resource_id="model-x",
|
||||
params={},
|
||||
)
|
||||
|
||||
shields_impl = AsyncMock()
|
||||
shields_impl.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||
|
||||
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||
|
||||
async def test_validate_wrong_shield_id(self):
|
||||
safety_config = SafetyConfig(default_shield_id="wrong-shield-id")
|
||||
|
||||
shields_impl = AsyncMock()
|
||||
shields_impl.list_shields.return_value = ListShieldsResponse(
|
||||
data=[
|
||||
Shield(
|
||||
identifier="shield-1",
|
||||
provider_resource_id="model-x",
|
||||
provider_id="provider-x",
|
||||
params={},
|
||||
)
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError, match="wrong-shield-id"):
|
||||
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue