diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index 6e6d705c3..2ca00067a 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -44,13 +44,6 @@ class ShieldDefinition(BaseModel): execution_config: Optional[RestAPIExecutionConfig] = None -@json_schema_type -class ShieldCall(BaseModel): - call_id: str - shield_type: ShieldType - arguments: Dict[str, str] - - @json_schema_type class ShieldResponse(BaseModel): shield_type: ShieldType diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py new file mode 100644 index 000000000..8558ed8fd --- /dev/null +++ b/llama_toolchain/safety/api/endpoints.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .datatypes import * # noqa: F403 +from typing import Protocol + +from llama_models.llama3_1.api.datatypes import Message + +# this dependency is annoying and we need a forked up version anyway +from pyopenapi import webmethod + + +@json_schema_type +class RunShieldRequest(BaseModel): + shield_type: ShieldType + messages: List[Message] + + +class SafetyCheck(Protocol): + + @webmethod(route="/safety/run_shield") + async def run_shield( + self, + request: RunShieldRequest, + ) -> ShieldResponse: ...