llama-stack-mirror/toolchain/safety/shields/contrib/third_party_shield.py
2024-07-19 12:30:35 -07:00

28 lines
802 B
Python

import sys
from typing import List
from models.llama3.datatypes import Message
parent_dir = "../.."
sys.path.append(parent_dir)
from toolchain.safety.shields.base import OnViolationAction, ShieldBase, ShieldResponse
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
    """Skeleton shield whose ``run`` is intentionally left unimplemented.

    The class wires up construction and a module-level shared instance, but
    provides no checking logic of its own: ``run`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
        """Return the shared ``ThirdPartyShield``, creating it on first use.

        Args:
            on_violation_action: Passed to the constructor only when the
                shared instance does not exist yet. Once an instance has been
                created, later calls return it unchanged and this argument is
                ignored.

        Returns:
            The module-level shared ``ThirdPartyShield`` instance.
        """
        global _INSTANCE
        if _INSTANCE is None:
            _INSTANCE = ThirdPartyShield(on_violation_action)
        return _INSTANCE

    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        """Initialise the shield, forwarding ``on_violation_action`` to the base class."""
        super().__init__(on_violation_action)

    async def run(self, messages: List[Message]) -> ShieldResponse:
        """Placeholder entry point; always raises.

        Args:
            messages: Accepted for interface compatibility; not inspected.

        Raises:
            NotImplementedError: Always, because no checking logic is provided.
        """
        # The previous body was `super.run()`, which looks up `run` on the
        # builtin `super` type and fails with AttributeError instead of the
        # intended NotImplementedError. Raise the intended error directly so
        # the behaviour does not depend on the base class's `run` signature.
        raise NotImplementedError(
            f"{type(self).__name__}.run() is not implemented"
        )