diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 84560b355..f80129fa1 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -54,11 +54,11 @@ async def get_auto_router_impl( DatasetIORouter, EvalRouter, InferenceRouter, - SafetyRouter, ScoringRouter, ToolRuntimeRouter, VectorIORouter, ) + from .safety import SafetyRouter api_to_routers = { "vector_io": VectorIORouter, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 0515b19f8..66ba7837c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -54,14 +54,12 @@ from llama_stack.apis.inference.inference import ( OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, ScoreResponse, Scoring, ScoringFnParams, ) -from llama_stack.apis.shields import Shield from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -673,46 +671,6 @@ class InferenceRouter(Inference): return health_statuses -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) - - class DatasetIORouter(DatasetIO): def __init__( self, diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py new file mode 100644 index 000000000..9761d2db0 --- /dev/null +++ b/llama_stack/distribution/routers/safety.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.inference import ( + Message, +) +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + return await self.routing_table.get_provider_impl(shield_id).run_shield( + shield_id=shield_id, + messages=messages, + params=params, + )