feat(lakera_ai.py): control running prompt injection between pre-call and in_parallel

This commit is contained in:
Krrish Dholakia 2024-07-22 20:04:42 -07:00
parent a32a7af215
commit 99a5436ed5
6 changed files with 211 additions and 37 deletions

View file

@ -10,7 +10,7 @@ import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Literal, List, Dict
from typing import Literal, List, Dict, Optional, Union
import litellm, sys
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
def __init__(self):
def __init__(
self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
self.moderation_check = moderation_check
pass
#### CALL HOOKS - proxy only ####
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: litellm.DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
],
) -> Optional[Union[Exception, str, Dict]]:
if self.moderation_check == "in_parallel":
return None
return await super().async_pre_call_hook(
user_api_key_dict, cache, data, call_type
)
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
if self.moderation_check == "pre_call":
return
if (
await should_proceed_based_on_metadata(