mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: split routers into individual files (safety)
Reviewers: bbrowning, leseb, ehhuang, terrytangyuan, raghotham, yanxi0830, hardikjshah Reviewed By: raghotham Pull Request: https://github.com/meta-llama/llama-stack/pull/2248
This commit is contained in:
parent
c290999c63
commit
a2160dc0af
3 changed files with 58 additions and 43 deletions
|
@ -54,11 +54,11 @@ async def get_auto_router_impl(
|
|||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
InferenceRouter,
|
||||
SafetyRouter,
|
||||
ScoringRouter,
|
||||
ToolRuntimeRouter,
|
||||
VectorIORouter,
|
||||
)
|
||||
from .safety import SafetyRouter
|
||||
|
||||
api_to_routers = {
|
||||
"vector_io": VectorIORouter,
|
||||
|
|
|
@ -54,14 +54,12 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
|
@ -673,46 +671,6 @@ class InferenceRouter(Inference):
|
|||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
57
llama_stack/distribution/routers/safety.py
Normal file
57
llama_stack/distribution/routers/safety.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue