chore: split routers into individual files (safety)

Reviewers:
bbrowning, leseb, ehhuang, terrytangyuan, raghotham, yanxi0830, hardikjshah

Reviewed By: raghotham

Pull Request: https://github.com/meta-llama/llama-stack/pull/2248
This commit is contained in:
Ashwin Bharambe 2025-05-24 22:00:32 -07:00 committed by GitHub
parent c290999c63
commit a2160dc0af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 43 deletions

View file

@ -54,11 +54,11 @@ async def get_auto_router_impl(
DatasetIORouter, DatasetIORouter,
EvalRouter, EvalRouter,
InferenceRouter, InferenceRouter,
SafetyRouter,
ScoringRouter, ScoringRouter,
ToolRuntimeRouter, ToolRuntimeRouter,
VectorIORouter, VectorIORouter,
) )
from .safety import SafetyRouter
api_to_routers = { api_to_routers = {
"vector_io": VectorIORouter, "vector_io": VectorIORouter,

View file

@ -54,14 +54,12 @@ from llama_stack.apis.inference.inference import (
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import ( from llama_stack.apis.scoring import (
ScoreBatchResponse, ScoreBatchResponse,
ScoreResponse, ScoreResponse,
Scoring, Scoring,
ScoringFnParams, ScoringFnParams,
) )
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
@ -673,46 +671,6 @@ class InferenceRouter(Inference):
return health_statuses return health_statuses
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO): class DatasetIORouter(DatasetIO):
def __init__( def __init__(
self, self,

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)