From c253c1c9ad093e5f72ff978409afeb985090395b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 31 Jul 2024 21:57:10 -0700 Subject: [PATCH] Begin adding a /safety/run_shield API --- llama_toolchain/safety/api/datatypes.py | 7 ------- llama_toolchain/safety/api/endpoints.py | 28 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 7 deletions(-) create mode 100644 llama_toolchain/safety/api/endpoints.py diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index 6e6d705c3..2ca00067a 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -44,13 +44,6 @@ class ShieldDefinition(BaseModel): execution_config: Optional[RestAPIExecutionConfig] = None -@json_schema_type -class ShieldCall(BaseModel): - call_id: str - shield_type: ShieldType - arguments: Dict[str, str] - - @json_schema_type class ShieldResponse(BaseModel): shield_type: ShieldType diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py new file mode 100644 index 000000000..8558ed8fd --- /dev/null +++ b/llama_toolchain/safety/api/endpoints.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .datatypes import * # noqa: F403 +from typing import Protocol + +from llama_models.llama3_1.api.datatypes import Message + +# this dependency is annoying and we need a forked up version anyway +from pyopenapi import webmethod + + +@json_schema_type +class RunShieldRequest(BaseModel): + shield_type: ShieldType + messages: List[Message] + + +class SafetyCheck(Protocol): + + @webmethod(route="/safety/run_shield") + async def run_shield( + self, + request: RunShieldRequest, + ) -> ShieldResponse: ...