mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
chore: support default model in moderations API
# What does this PR do? ## Test Plan
This commit is contained in:
parent
7b90e0e9c8
commit
4f7fedd91d
23 changed files with 189 additions and 36 deletions
|
|
@ -6443,11 +6443,10 @@ components:
|
||||||
model:
|
model:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
The content moderation model you would like to use.
|
(Optional) The content moderation model you would like to use.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- input
|
- input
|
||||||
- model
|
|
||||||
title: RunModerationRequest
|
title: RunModerationRequest
|
||||||
ModerationObject:
|
ModerationObject:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
||||||
5
docs/static/deprecated-llama-stack-spec.html
vendored
5
docs/static/deprecated-llama-stack-spec.html
vendored
|
|
@ -8185,13 +8185,12 @@
|
||||||
},
|
},
|
||||||
"model": {
|
"model": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The content moderation model you would like to use."
|
"description": "(Optional) The content moderation model you would like to use."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"input",
|
"input"
|
||||||
"model"
|
|
||||||
],
|
],
|
||||||
"title": "RunModerationRequest"
|
"title": "RunModerationRequest"
|
||||||
},
|
},
|
||||||
|
|
|
||||||
3
docs/static/deprecated-llama-stack-spec.yaml
vendored
3
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -6104,11 +6104,10 @@ components:
|
||||||
model:
|
model:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
The content moderation model you would like to use.
|
(Optional) The content moderation model you would like to use.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- input
|
- input
|
||||||
- model
|
|
||||||
title: RunModerationRequest
|
title: RunModerationRequest
|
||||||
ModerationObject:
|
ModerationObject:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
||||||
5
docs/static/llama-stack-spec.html
vendored
5
docs/static/llama-stack-spec.html
vendored
|
|
@ -6919,13 +6919,12 @@
|
||||||
},
|
},
|
||||||
"model": {
|
"model": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The content moderation model you would like to use."
|
"description": "(Optional) The content moderation model you would like to use."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"input",
|
"input"
|
||||||
"model"
|
|
||||||
],
|
],
|
||||||
"title": "RunModerationRequest"
|
"title": "RunModerationRequest"
|
||||||
},
|
},
|
||||||
|
|
|
||||||
3
docs/static/llama-stack-spec.yaml
vendored
3
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -5230,11 +5230,10 @@ components:
|
||||||
model:
|
model:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
The content moderation model you would like to use.
|
(Optional) The content moderation model you would like to use.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- input
|
- input
|
||||||
- model
|
|
||||||
title: RunModerationRequest
|
title: RunModerationRequest
|
||||||
ModerationObject:
|
ModerationObject:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
||||||
5
docs/static/stainless-llama-stack-spec.html
vendored
5
docs/static/stainless-llama-stack-spec.html
vendored
|
|
@ -8591,13 +8591,12 @@
|
||||||
},
|
},
|
||||||
"model": {
|
"model": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The content moderation model you would like to use."
|
"description": "(Optional) The content moderation model you would like to use."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"input",
|
"input"
|
||||||
"model"
|
|
||||||
],
|
],
|
||||||
"title": "RunModerationRequest"
|
"title": "RunModerationRequest"
|
||||||
},
|
},
|
||||||
|
|
|
||||||
3
docs/static/stainless-llama-stack-spec.yaml
vendored
3
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -6443,11 +6443,10 @@ components:
|
||||||
model:
|
model:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
The content moderation model you would like to use.
|
(Optional) The content moderation model you would like to use.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- input
|
- input
|
||||||
- model
|
|
||||||
title: RunModerationRequest
|
title: RunModerationRequest
|
||||||
ModerationObject:
|
ModerationObject:
|
||||||
type: object
|
type: object
|
||||||
|
|
|
||||||
|
|
@ -123,13 +123,13 @@ class Safety(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
"""Create moderation.
|
"""Create moderation.
|
||||||
|
|
||||||
Classifies if text and/or image inputs are potentially harmful.
|
Classifies if text and/or image inputs are potentially harmful.
|
||||||
:param input: Input (or inputs) to classify.
|
:param input: Input (or inputs) to classify.
|
||||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||||
:param model: The content moderation model you would like to use.
|
:param model: (Optional) The content moderation model you would like to use.
|
||||||
:returns: A moderation object.
|
:returns: A moderation object.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
||||||
|
|
@ -374,6 +374,15 @@ class VectorStoresConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyConfig(BaseModel):
|
||||||
|
"""Configuration for default moderations model."""
|
||||||
|
|
||||||
|
default_shield_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuotaPeriod(StrEnum):
|
class QuotaPeriod(StrEnum):
|
||||||
DAY = "day"
|
DAY = "day"
|
||||||
|
|
||||||
|
|
@ -532,6 +541,11 @@ can be instantiated multiple times (with different configs) if necessary.
|
||||||
description="Configuration for vector stores, including default embedding model",
|
description="Configuration for vector stores, including default embedding model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
safety: SafetyConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Configuration for default moderations model",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("external_providers_dir")
|
@field_validator("external_providers_dir")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_external_providers_dir(cls, v):
|
def validate_external_providers_dir(cls, v):
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,8 @@ async def get_auto_router_impl(
|
||||||
|
|
||||||
elif api == Api.vector_io:
|
elif api == Api.vector_io:
|
||||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||||
|
elif api == Api.safety:
|
||||||
|
api_to_dep_impl["safety_config"] = run_config.safety
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.safety.safety import ModerationObject
|
from llama_stack.apis.safety.safety import ModerationObject
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.core.datatypes import SafetyConfig
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
|
|
@ -20,9 +21,11 @@ class SafetyRouter(Safety):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
|
safety_config: SafetyConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing SafetyRouter")
|
logger.debug("Initializing SafetyRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
self.safety_config = safety_config
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logger.debug("SafetyRouter.initialize")
|
logger.debug("SafetyRouter.initialize")
|
||||||
|
|
@ -60,30 +63,47 @@ class SafetyRouter(Safety):
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
async def get_shield_id(self, model: str) -> str:
|
list_shields_response = await self.routing_table.list_shields()
|
||||||
"""Get Shield id from model (provider_resource_id) of shield."""
|
shields = list_shields_response.data
|
||||||
list_shields_response = await self.routing_table.list_shields()
|
|
||||||
|
|
||||||
matches: list[str] = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
selected_shield: Shield | None = None
|
||||||
|
provider_model: str | None = model
|
||||||
|
|
||||||
|
if model:
|
||||||
|
matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
|
||||||
if not matches:
|
if not matches:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in list_shields_response.data]}"
|
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
|
||||||
)
|
)
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Multiple shields associated with provider_resource id {model}: matched shields {matches}"
|
f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
|
||||||
|
)
|
||||||
|
selected_shield = matches[0]
|
||||||
|
else:
|
||||||
|
default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
|
||||||
|
if not default_shield_id:
|
||||||
|
raise ValueError(
|
||||||
|
"No moderation model specified and no default_shield_id configured in safety config: select model "
|
||||||
|
f"from {[s.provider_resource_id or s.identifier for s in shields]}"
|
||||||
)
|
)
|
||||||
return matches[0]
|
|
||||||
|
|
||||||
shield_id = await get_shield_id(self, model)
|
selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
|
||||||
|
if selected_shield is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Default moderation model not found. Choose from {[s.provider_resource_id or s.identifier for s in shields]}."
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model = selected_shield.provider_resource_id or selected_shield.identifier
|
||||||
|
|
||||||
|
shield_id = selected_shield.identifier
|
||||||
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
|
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||||
|
|
||||||
response = await provider.run_moderation(
|
response = await provider.run_moderation(
|
||||||
input=input,
|
input=input,
|
||||||
model=model,
|
model=provider_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||||
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
|
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||||
from llama_stack.core.distribution import get_provider_registry
|
from llama_stack.core.distribution import get_provider_registry
|
||||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||||
|
|
@ -175,6 +175,30 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
|
||||||
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
|
||||||
|
if safety_config is None or safety_config.default_shield_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if Api.shields not in impls:
|
||||||
|
raise ValueError("Safety configuration requires the shields API to be enabled")
|
||||||
|
|
||||||
|
if Api.safety not in impls:
|
||||||
|
raise ValueError("Safety configuration requires the safety API to be enabled")
|
||||||
|
|
||||||
|
shields_impl = impls[Api.shields]
|
||||||
|
response = await shields_impl.list_shields()
|
||||||
|
shields_by_id = {shield.identifier: shield for shield in response.data}
|
||||||
|
|
||||||
|
default_shield_id = safety_config.default_shield_id
|
||||||
|
# don't validate if there are no shields registered
|
||||||
|
if shields_by_id and default_shield_id not in shields_by_id:
|
||||||
|
available = sorted(shields_by_id)
|
||||||
|
raise ValueError(
|
||||||
|
f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
|
||||||
|
f" Available shields: {available}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
def __init__(self, var_name: str, path: str = ""):
|
def __init__(self, var_name: str, path: str = ""):
|
||||||
self.var_name = var_name
|
self.var_name = var_name
|
||||||
|
|
@ -412,6 +436,7 @@ class Stack:
|
||||||
await register_resources(self.run_config, impls)
|
await register_resources(self.run_config, impls)
|
||||||
await refresh_registry_once(impls)
|
await refresh_registry_once(impls)
|
||||||
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
||||||
|
await validate_safety_config(self.run_config.safety, impls)
|
||||||
self.impls = impls
|
self.impls = impls
|
||||||
|
|
||||||
def create_registry_refresh_task(self):
|
def create_registry_refresh_task(self):
|
||||||
|
|
|
||||||
|
|
@ -274,3 +274,5 @@ vector_stores:
|
||||||
default_embedding_model:
|
default_embedding_model:
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||||
|
safety:
|
||||||
|
default_shield_id: llama-guard
|
||||||
|
|
|
||||||
|
|
@ -277,3 +277,5 @@ vector_stores:
|
||||||
default_embedding_model:
|
default_embedding_model:
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||||
|
safety:
|
||||||
|
default_shield_id: llama-guard
|
||||||
|
|
|
||||||
|
|
@ -274,3 +274,5 @@ vector_stores:
|
||||||
default_embedding_model:
|
default_embedding_model:
|
||||||
provider_id: sentence-transformers
|
provider_id: sentence-transformers
|
||||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||||
|
safety:
|
||||||
|
default_shield_id: llama-guard
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from llama_stack.core.datatypes import (
|
||||||
Provider,
|
Provider,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
QualifiedModel,
|
QualifiedModel,
|
||||||
|
SafetyConfig,
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
VectorStoresConfig,
|
VectorStoresConfig,
|
||||||
|
|
@ -256,6 +257,9 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
safety_config=SafetyConfig(
|
||||||
|
default_shield_id="llama-guard",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
run_config_env_vars={
|
run_config_env_vars={
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from llama_stack.core.datatypes import (
|
||||||
DistributionSpec,
|
DistributionSpec,
|
||||||
ModelInput,
|
ModelInput,
|
||||||
Provider,
|
Provider,
|
||||||
|
SafetyConfig,
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
TelemetryConfig,
|
TelemetryConfig,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
|
|
@ -188,6 +189,7 @@ class RunConfigSettings(BaseModel):
|
||||||
default_datasets: list[DatasetInput] | None = None
|
default_datasets: list[DatasetInput] | None = None
|
||||||
default_benchmarks: list[BenchmarkInput] | None = None
|
default_benchmarks: list[BenchmarkInput] | None = None
|
||||||
vector_stores_config: VectorStoresConfig | None = None
|
vector_stores_config: VectorStoresConfig | None = None
|
||||||
|
safety_config: SafetyConfig | None = None
|
||||||
telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
|
telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
|
||||||
storage_backends: dict[str, Any] | None = None
|
storage_backends: dict[str, Any] | None = None
|
||||||
storage_stores: dict[str, Any] | None = None
|
storage_stores: dict[str, Any] | None = None
|
||||||
|
|
@ -290,6 +292,9 @@ class RunConfigSettings(BaseModel):
|
||||||
if self.vector_stores_config:
|
if self.vector_stores_config:
|
||||||
config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)
|
config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
if self.safety_config:
|
||||||
|
config["safety"] = self.safety_config.model_dump(exclude_none=True)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
|
if model is None:
|
||||||
|
raise ValueError("Code scanner moderation requires a model identifier.")
|
||||||
|
|
||||||
inputs = input if isinstance(input, list) else [input]
|
inputs = input if isinstance(input, list) else [input]
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -200,7 +200,10 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
return await impl.run(messages)
|
return await impl.run(messages)
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
|
if model is None:
|
||||||
|
raise ValueError("Llama Guard moderation requires a model identifier.")
|
||||||
|
|
||||||
if isinstance(input, list):
|
if isinstance(input, list):
|
||||||
messages = input.copy()
|
messages = input.copy()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
return await self.shield.run(messages)
|
return await self.shield.run(messages)
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||||
return await self.shield.run(messages)
|
return await self.shield.run(messages)
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
|
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
43
tests/unit/core/routers/test_safety_router.py
Normal file
43
tests/unit/core/routers/test_safety_router.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||||
|
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||||
|
from llama_stack.core.datatypes import SafetyConfig
|
||||||
|
from llama_stack.core.routers.safety import SafetyRouter
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_moderation_uses_default_shield_when_model_missing():
|
||||||
|
routing_table = AsyncMock()
|
||||||
|
shield = Shield(
|
||||||
|
identifier="shield-1",
|
||||||
|
provider_resource_id="provider/shield-model",
|
||||||
|
provider_id="provider-id",
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
routing_table.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||||
|
|
||||||
|
moderation_response = ModerationObject(
|
||||||
|
id="mid",
|
||||||
|
model="shield-1",
|
||||||
|
results=[ModerationObjectResults(flagged=False)],
|
||||||
|
)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.run_moderation.return_value = moderation_response
|
||||||
|
routing_table.get_provider_impl.return_value = provider
|
||||||
|
|
||||||
|
router = SafetyRouter(routing_table=routing_table, safety_config=SafetyConfig(default_shield_id="shield-1"))
|
||||||
|
|
||||||
|
result = await router.run_moderation("hello world")
|
||||||
|
|
||||||
|
assert result is moderation_response
|
||||||
|
routing_table.get_provider_impl.assert_awaited_once_with("shield-1")
|
||||||
|
provider.run_moderation.assert_awaited_once()
|
||||||
|
_, kwargs = provider.run_moderation.call_args
|
||||||
|
assert kwargs["model"] == "provider/shield-model"
|
||||||
|
assert kwargs["input"] == "hello world"
|
||||||
|
|
@ -11,8 +11,9 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
|
||||||
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
|
from llama_stack.apis.shields import ListShieldsResponse, Shield
|
||||||
from llama_stack.core.stack import validate_vector_stores_config
|
from llama_stack.core.datatypes import QualifiedModel, SafetyConfig, StackRunConfig, StorageConfig, VectorStoresConfig
|
||||||
|
from llama_stack.core.stack import validate_safety_config, validate_vector_stores_config
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,3 +66,37 @@ class TestVectorStoresValidation:
|
||||||
)
|
)
|
||||||
|
|
||||||
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyConfigValidation:
|
||||||
|
async def test_validate_success(self):
|
||||||
|
safety_config = SafetyConfig(default_shield_id="shield-1")
|
||||||
|
|
||||||
|
shield = Shield(
|
||||||
|
identifier="shield-1",
|
||||||
|
provider_id="provider-x",
|
||||||
|
provider_resource_id="model-x",
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
|
||||||
|
shields_impl = AsyncMock()
|
||||||
|
shields_impl.list_shields.return_value = ListShieldsResponse(data=[shield])
|
||||||
|
|
||||||
|
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||||
|
|
||||||
|
async def test_validate_wrong_shield_id(self):
|
||||||
|
safety_config = SafetyConfig(default_shield_id="wrong-shield-id")
|
||||||
|
|
||||||
|
shields_impl = AsyncMock()
|
||||||
|
shields_impl.list_shields.return_value = ListShieldsResponse(
|
||||||
|
data=[
|
||||||
|
Shield(
|
||||||
|
identifier="shield-1",
|
||||||
|
provider_resource_id="model-x",
|
||||||
|
provider_id="provider-x",
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="wrong-shield-id"):
|
||||||
|
await validate_safety_config(safety_config, {Api.shields: shields_impl, Api.safety: AsyncMock()})
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue