diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml
index bd22f2129..bd2d4b7a4 100644
--- a/client-sdks/stainless/openapi.yml
+++ b/client-sdks/stainless/openapi.yml
@@ -6443,11 +6443,10 @@ components:
         model:
           type: string
           description: >-
-            The content moderation model you would like to use.
+            (Optional) The content moderation model you would like to use.
       additionalProperties: false
       required:
         - input
-        - model
       title: RunModerationRequest
     ModerationObject:
       type: object
diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html
index e3e182dd7..8bd8ecf3f 100644
--- a/docs/static/deprecated-llama-stack-spec.html
+++ b/docs/static/deprecated-llama-stack-spec.html
@@ -8185,13 +8185,12 @@
         },
         "model": {
           "type": "string",
-          "description": "The content moderation model you would like to use."
+          "description": "(Optional) The content moderation model you would like to use."
         }
       },
       "additionalProperties": false,
       "required": [
-        "input",
-        "model"
+        "input"
       ],
       "title": "RunModerationRequest"
     },
diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml
index 6b5b8230a..cd86239e8 100644
--- a/docs/static/deprecated-llama-stack-spec.yaml
+++ b/docs/static/deprecated-llama-stack-spec.yaml
@@ -6104,11 +6104,10 @@ components:
         model:
           type: string
           description: >-
-            The content moderation model you would like to use.
+            (Optional) The content moderation model you would like to use.
       additionalProperties: false
       required:
         - input
-        - model
       title: RunModerationRequest
     ModerationObject:
       type: object
diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html
index 384770954..d9dbe27c9 100644
--- a/docs/static/llama-stack-spec.html
+++ b/docs/static/llama-stack-spec.html
@@ -6919,13 +6919,12 @@
         },
         "model": {
           "type": "string",
-          "description": "The content moderation model you would like to use."
+          "description": "(Optional) The content moderation model you would like to use."
         }
      },
       "additionalProperties": false,
       "required": [
-        "input",
-        "model"
+        "input"
       ],
       "title": "RunModerationRequest"
     },
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index 36b9c7153..604a4eace 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -5230,11 +5230,10 @@ components:
         model:
           type: string
           description: >-
-            The content moderation model you would like to use.
+            (Optional) The content moderation model you would like to use.
       additionalProperties: false
       required:
         - input
-        - model
       title: RunModerationRequest
     ModerationObject:
       type: object
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 77a64ced0..687c35db8 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -8591,13 +8591,12 @@
         },
         "model": {
           "type": "string",
-          "description": "The content moderation model you would like to use."
+          "description": "(Optional) The content moderation model you would like to use."
         }
       },
       "additionalProperties": false,
       "required": [
-        "input",
-        "model"
+        "input"
       ],
       "title": "RunModerationRequest"
     },
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index bd22f2129..bd2d4b7a4 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -6443,11 +6443,10 @@ components:
         model:
           type: string
           description: >-
-            The content moderation model you would like to use.
+            (Optional) The content moderation model you would like to use.
       additionalProperties: false
       required:
         - input
-        - model
       title: RunModerationRequest
     ModerationObject:
       type: object
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index eaaa937d3..f6d51871b 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -123,13 +123,13 @@ class Safety(Protocol):
 
     @webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         """Create moderation. Classifies if text and/or image inputs are potentially harmful.
 
         :param input: Input (or inputs) to classify. Can be a single string, an array of strings, or an array of
             multi-modal input objects similar to other models.
-        :param model: The content moderation model you would like to use.
+        :param model: (Optional) The content moderation model you would like to use.
         :returns: A moderation object.
         """
         ...
description="Configuration for vector stores, including default embedding model", ) + safety: SafetyConfig | None = Field( + default=None, + description="Configuration for default moderations model", + ) + @field_validator("external_providers_dir") @classmethod def validate_external_providers_dir(cls, v): diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py index 20c17e59d..2f35fe04f 100644 --- a/llama_stack/core/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -95,6 +95,8 @@ async def get_auto_router_impl( elif api == Api.vector_io: api_to_dep_impl["vector_stores_config"] = run_config.vector_stores + elif api == Api.safety: + api_to_dep_impl["safety_config"] = run_config.safety impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index d171c9721..79eac8b46 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield +from llama_stack.core.datatypes import SafetyConfig from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable @@ -20,9 +21,11 @@ class SafetyRouter(Safety): def __init__( self, routing_table: RoutingTable, + safety_config: SafetyConfig | None = None, ) -> None: logger.debug("Initializing SafetyRouter") self.routing_table = routing_table + self.safety_config = safety_config async def initialize(self) -> None: logger.debug("SafetyRouter.initialize") @@ -60,30 +63,47 @@ class SafetyRouter(Safety): params=params, ) - async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: - async def get_shield_id(self, model: str) -> str: - """Get Shield id from model (provider_resource_id) of shield.""" - list_shields_response = await self.routing_table.list_shields() + async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject: + list_shields_response = await self.routing_table.list_shields() + shields = list_shields_response.data - matches: list[str] = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] + selected_shield: Shield | None = None + provider_model: str | None = model + if model: + matches: list[Shield] = [s for s in shields if model == s.provider_resource_id] if not matches: raise ValueError( - f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in list_shields_response.data]}" + f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}" ) if len(matches) > 1: raise ValueError( - f"Multiple shields associated with provider_resource id {model}: matched shields {matches}" + f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}" + ) + selected_shield = matches[0] + else: + default_shield_id = self.safety_config.default_shield_id if self.safety_config else None + if not default_shield_id: + raise ValueError( + "No moderation model specified and no default_shield_id configured in safety config: select model " + f"from {[s.provider_resource_id or s.identifier for s in shields]}" ) - return matches[0] - shield_id = await get_shield_id(self, model) + 
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 4cf1d072d..ebfd59a05 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
+from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -175,6 +175,30 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
     logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
 
 
+async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
+    if safety_config is None or safety_config.default_shield_id is None:
+        return
+
+    if Api.shields not in impls:
+        raise ValueError("Safety configuration requires the shields API to be enabled")
+
+    if Api.safety not in impls:
+        raise ValueError("Safety configuration requires the safety API to be enabled")
+
+    shields_impl = impls[Api.shields]
+    response = await shields_impl.list_shields()
+    shields_by_id = {shield.identifier: shield for shield in response.data}
+
+    default_shield_id = safety_config.default_shield_id
+    # don't validate if there are no shields registered
+    if shields_by_id and default_shield_id not in shields_by_id:
+        available = sorted(shields_by_id)
+        raise ValueError(
+            f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
+            f" Available shields: {available}"
+        )
+
+
 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
         self.var_name = var_name
@@ -412,6 +436,7 @@ class Stack:
         await register_resources(self.run_config, impls)
         await refresh_registry_once(impls)
         await validate_vector_stores_config(self.run_config.vector_stores, impls)
+        await validate_safety_config(self.run_config.safety, impls)
         self.impls = impls
 
     def create_registry_refresh_task(self):
+ f" Available shields: {available}" + ) + + class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): self.var_name = var_name @@ -412,6 +436,7 @@ class Stack: await register_resources(self.run_config, impls) await refresh_registry_once(impls) await validate_vector_stores_config(self.run_config.vector_stores, impls) + await validate_safety_config(self.run_config.safety, impls) self.impls = impls def create_registry_refresh_task(self): diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index ecf9eed3b..ed880d4a0 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -274,3 +274,5 @@ vector_stores: default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 +safety: + default_shield_id: llama-guard diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index 92483c78e..33e8c9b59 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -277,3 +277,5 @@ vector_stores: default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 +safety: + default_shield_id: llama-guard diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 3b9d8f890..4ca0914af 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -274,3 +274,5 @@ vector_stores: default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 +safety: + default_shield_id: llama-guard diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index c8c7101a6..49b7a2463 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -12,6 +12,7 @@ from llama_stack.core.datatypes import ( Provider, ProviderSpec, QualifiedModel, + SafetyConfig, ShieldInput, ToolGroupInput, VectorStoresConfig, @@ -256,6 +257,9 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: model_id="nomic-ai/nomic-embed-text-v1.5", ), ), + safety_config=SafetyConfig( + default_shield_id="llama-guard", + ), ), }, run_config_env_vars={ diff --git a/llama_stack/distributions/template.py b/llama_stack/distributions/template.py index 64f21e626..f0c4c6b9e 100644 --- a/llama_stack/distributions/template.py +++ b/llama_stack/distributions/template.py @@ -24,6 +24,7 @@ from llama_stack.core.datatypes import ( DistributionSpec, ModelInput, Provider, + SafetyConfig, ShieldInput, TelemetryConfig, ToolGroupInput, @@ -188,6 +189,7 @@ class RunConfigSettings(BaseModel): default_datasets: list[DatasetInput] | None = None default_benchmarks: list[BenchmarkInput] | None = None vector_stores_config: VectorStoresConfig | None = None + safety_config: SafetyConfig | None = None telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True)) storage_backends: dict[str, Any] | None = None storage_stores: dict[str, Any] | None = None @@ -290,6 +292,9 @@ class RunConfigSettings(BaseModel): if self.vector_stores_config: config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True) + if self.safety_config: + config["safety"] = self.safety_config.model_dump(exclude_none=True) + return config diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py 
diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index e1cd8c5e4..7da9ea0d7 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -101,7 +101,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Code scanner moderation requires a model identifier.")
+
         inputs = input if isinstance(input, list) else [input]
         results = []
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 47c6ccbed..6f6346e82 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -200,7 +200,10 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await impl.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Llama Guard moderation requires a model identifier.")
+
         if isinstance(input, list):
             messages = input.copy()
         else:
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 8ca96300f..2015e1150 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -63,7 +63,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await self.shield.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py
index c0df8f095..236f16207 100644
--- a/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -66,7 +66,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
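Note: the provider-side guards above are defensive. The SafetyRouter always forwards the resolved shield's provider_resource_id, so `model` can only be None when a provider is invoked directly. A minimal sketch of that contract, using a simplified stand-in rather than the actual provider classes:

```python
# Simplified stand-in for the providers' new None-guard; not the real classes.
import asyncio


async def run_moderation(input: str | list[str], model: str | None = None) -> dict:
    if model is None:
        # Only reachable when a provider is called directly, bypassing the
        # SafetyRouter, which always supplies a concrete provider model.
        raise ValueError("This moderation provider requires a model identifier.")
    return {"model": model, "flagged": False}  # stub result


async def main() -> None:
    try:
        await run_moderation("hello")  # direct call without a model is rejected
    except ValueError as exc:
        print(f"rejected: {exc}")
    print(await run_moderation("hello", model="provider/shield-model"))


asyncio.run(main())
```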
diff --git a/tests/unit/core/routers/test_safety_router.py b/tests/unit/core/routers/test_safety_router.py
new file mode 100644
index 000000000..bf195ff33
--- /dev/null
+++ b/tests/unit/core/routers/test_safety_router.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import AsyncMock
+
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
+from llama_stack.apis.shields import ListShieldsResponse, Shield
+from llama_stack.core.datatypes import SafetyConfig
+from llama_stack.core.routers.safety import SafetyRouter
+
+
+async def test_run_moderation_uses_default_shield_when_model_missing():
+    routing_table = AsyncMock()
+    shield = Shield(
+        identifier="shield-1",
+        provider_resource_id="provider/shield-model",
+        provider_id="provider-id",
+        params={},
+    )
+    routing_table.list_shields.return_value = ListShieldsResponse(data=[shield])
+
+    moderation_response = ModerationObject(
+        id="mid",
+        model="shield-1",
+        results=[ModerationObjectResults(flagged=False)],
+    )
+    provider = AsyncMock()
+    provider.run_moderation.return_value = moderation_response
+    routing_table.get_provider_impl.return_value = provider
+
+    router = SafetyRouter(routing_table=routing_table, safety_config=SafetyConfig(default_shield_id="shield-1"))
+
+    result = await router.run_moderation("hello world")
+
+    assert result is moderation_response
+    routing_table.get_provider_impl.assert_awaited_once_with("shield-1")
+    provider.run_moderation.assert_awaited_once()
+    _, kwargs = provider.run_moderation.call_args
+    assert kwargs["model"] == "provider/shield-model"
+    assert kwargs["input"] == "hello world"
diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py
index fa5348d1c..d28803006 100644
--- a/tests/unit/core/test_stack_validation.py
+++ b/tests/unit/core/test_stack_validation.py
@@ -11,8 +11,9 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.models import ListModelsResponse, Model, ModelType
-from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
-from llama_stack.core.stack import validate_vector_stores_config
+from llama_stack.apis.shields import ListShieldsResponse, Shield
+from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
+from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
 from llama_stack.providers.datatypes import Api
 
 
@@ -65,3 +66,37 @@ class TestVectorStoresValidation:
         )
 
         await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
+
+
+class TestSafetyConfigValidation:
+    async def test_validate_success(self):
+        safety_config = SafetyConfig(default_shield_id="shield-1")
+
+        shield = Shield(
+            identifier="shield-1",
+            provider_id="provider-x",
+            provider_resource_id="model-x",
+            params={},
+        )
+
+        shields_impl = AsyncMock()
+        shields_impl.list_shields.return_value = ListShieldsResponse(data=[shield])
+
+        await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
+
+    async def test_validate_wrong_shield_id(self):
+        safety_config = SafetyConfig(default_shield_id="wrong-shield-id")
+
+        shields_impl = AsyncMock()
+        shields_impl.list_shields.return_value = ListShieldsResponse(
+            data=[
+                Shield(
+                    identifier="shield-1",
+                    provider_resource_id="model-x",
+                    provider_id="provider-x",
+                    params={},
+                )
+            ]
+        )
+        with pytest.raises(ValueError, match="wrong-shield-id"):
+            await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})