mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
chore: support default model in moderations API
# What does this PR do? ## Test Plan
This commit is contained in:
parent
7b90e0e9c8
commit
3381874de9
23 changed files with 211 additions and 36 deletions
|
|
@ -374,6 +374,15 @@ class VectorStoresConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
"""Configuration for default moderations model."""
|
||||
|
||||
default_shield_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of the shield to use for when `model` is not specified in the `moderations` API request.",
|
||||
)
|
||||
|
||||
|
||||
class QuotaPeriod(StrEnum):
|
||||
DAY = "day"
|
||||
|
||||
|
|
@ -532,6 +541,11 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
description="Configuration for vector stores, including default embedding model",
|
||||
)
|
||||
|
||||
safety: SafetyConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for default moderations model",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
|
|
|||
|
|
@ -95,6 +95,8 @@ async def get_auto_router_impl(
|
|||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
elif api == Api.safety:
|
||||
api_to_dep_impl["safety_config"] = run_config.safety
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from llama_stack.apis.inference import Message
|
|||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
|
|
@ -20,9 +21,11 @@ class SafetyRouter(Safety):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
safety_config: SafetyConfig | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
self.safety_config = safety_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
|
|
@ -60,30 +63,47 @@ class SafetyRouter(Safety):
|
|||
params=params,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def get_shield_id(self, model: str) -> str:
|
||||
"""Get Shield id from model (provider_resource_id) of shield."""
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
shields = list_shields_response.data
|
||||
|
||||
matches: list[str] = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
||||
selected_shield: Shield | None = None
|
||||
provider_model: str | None = model
|
||||
|
||||
if model:
|
||||
matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
|
||||
if not matches:
|
||||
raise ValueError(
|
||||
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in list_shields_response.data]}"
|
||||
f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
|
||||
)
|
||||
if len(matches) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple shields associated with provider_resource id {model}: matched shields {matches}"
|
||||
f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
|
||||
)
|
||||
selected_shield = matches[0]
|
||||
else:
|
||||
default_shield_id = self.safety_config.default_shield_id if self.safety_config else None
|
||||
if not default_shield_id:
|
||||
raise ValueError(
|
||||
"No moderation model specified and no default_shield_id configured in safety config: select model "
|
||||
f"from {[s.provider_resource_id or s.identifier for s in shields]}"
|
||||
)
|
||||
return matches[0]
|
||||
|
||||
shield_id = await get_shield_id(self, model)
|
||||
selected_shield = next((s for s in shields if s.identifier == default_shield_id), None)
|
||||
if selected_shield is None:
|
||||
raise ValueError(
|
||||
f"Configured default_shield_id '{default_shield_id}' not found. Available shields: {[s.identifier for s in shields]}"
|
||||
)
|
||||
|
||||
provider_model = selected_shield.provider_resource_id or selected_shield.identifier
|
||||
|
||||
shield_id = selected_shield.identifier
|
||||
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
|
||||
response = await provider.run_moderation(
|
||||
input=input,
|
||||
model=model,
|
||||
model=provider_model,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
|
|
@ -175,6 +175,29 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
|
|||
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
||||
|
||||
|
||||
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
|
||||
if safety_config is None or safety_config.default_shield_id is None:
|
||||
return
|
||||
|
||||
if Api.shields not in impls:
|
||||
raise ValueError("Safety configuration requires the shields API to be enabled")
|
||||
|
||||
if Api.safety not in impls:
|
||||
raise ValueError("Safety configuration requires the safety API to be enabled")
|
||||
|
||||
shields_impl = impls[Api.shields]
|
||||
response = await shields_impl.list_shields()
|
||||
shields_by_id = {shield.identifier: shield for shield in response.data}
|
||||
|
||||
default_shield_id = safety_config.default_shield_id
|
||||
if default_shield_id not in shields_by_id:
|
||||
available = sorted(shields_by_id)
|
||||
raise ValueError(
|
||||
f"Configured default_shield_id '{default_shield_id}' not found among registered shields."
|
||||
f" Available shields: {available}"
|
||||
)
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
self.var_name = var_name
|
||||
|
|
@ -412,6 +435,7 @@ class Stack:
|
|||
await register_resources(self.run_config, impls)
|
||||
await refresh_registry_once(impls)
|
||||
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
||||
await validate_safety_config(self.run_config.safety, impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue