Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: split routers into individual files (safety)
Reviewers: bbrowning, leseb, ehhuang, terrytangyuan, raghotham, yanxi0830, hardikjshah Reviewed By: raghotham Pull Request: https://github.com/meta-llama/llama-stack/pull/2248
This commit is contained in:
parent
c290999c63
commit
a2160dc0af
3 changed files with 58 additions and 43 deletions
|
@ -54,11 +54,11 @@ async def get_auto_router_impl(
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
InferenceRouter,
|
InferenceRouter,
|
||||||
SafetyRouter,
|
|
||||||
ScoringRouter,
|
ScoringRouter,
|
||||||
ToolRuntimeRouter,
|
ToolRuntimeRouter,
|
||||||
VectorIORouter,
|
VectorIORouter,
|
||||||
)
|
)
|
||||||
|
from .safety import SafetyRouter
|
||||||
|
|
||||||
api_to_routers = {
|
api_to_routers = {
|
||||||
"vector_io": VectorIORouter,
|
"vector_io": VectorIORouter,
|
||||||
|
|
|
@ -54,14 +54,12 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
|
||||||
from llama_stack.apis.scoring import (
|
from llama_stack.apis.scoring import (
|
||||||
ScoreBatchResponse,
|
ScoreBatchResponse,
|
||||||
ScoreResponse,
|
ScoreResponse,
|
||||||
Scoring,
|
Scoring,
|
||||||
ScoringFnParams,
|
ScoringFnParams,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
|
||||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
|
@ -673,46 +671,6 @@ class InferenceRouter(Inference):
|
||||||
return health_statuses
|
return health_statuses
|
||||||
|
|
||||||
|
|
||||||
class SafetyRouter(Safety):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing SafetyRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("SafetyRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("SafetyRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def register_shield(
|
|
||||||
self,
|
|
||||||
shield_id: str,
|
|
||||||
provider_shield_id: str | None = None,
|
|
||||||
provider_id: str | None = None,
|
|
||||||
params: dict[str, Any] | None = None,
|
|
||||||
) -> Shield:
|
|
||||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
|
||||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
|
||||||
|
|
||||||
async def run_shield(
|
|
||||||
self,
|
|
||||||
shield_id: str,
|
|
||||||
messages: list[Message],
|
|
||||||
params: dict[str, Any] = None,
|
|
||||||
) -> RunShieldResponse:
|
|
||||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
|
||||||
shield_id=shield_id,
|
|
||||||
messages=messages,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetIORouter(DatasetIO):
|
class DatasetIORouter(DatasetIO):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
57
llama_stack/distribution/routers/safety.py
Normal file
57
llama_stack/distribution/routers/safety.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
Message,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class SafetyRouter(Safety):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
routing_table: RoutingTable,
|
||||||
|
) -> None:
|
||||||
|
logger.debug("Initializing SafetyRouter")
|
||||||
|
self.routing_table = routing_table
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
logger.debug("SafetyRouter.initialize")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
logger.debug("SafetyRouter.shutdown")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_shield(
|
||||||
|
self,
|
||||||
|
shield_id: str,
|
||||||
|
provider_shield_id: str | None = None,
|
||||||
|
provider_id: str | None = None,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> Shield:
|
||||||
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
|
|
||||||
|
async def run_shield(
|
||||||
|
self,
|
||||||
|
shield_id: str,
|
||||||
|
messages: list[Message],
|
||||||
|
params: dict[str, Any] = None,
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||||
|
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||||
|
shield_id=shield_id,
|
||||||
|
messages=messages,
|
||||||
|
params=params,
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue