# Reconstructed from a whitespace-collapsed patch to
# enterprise/enterprise_hooks/lakera_ai.py. The patch adds the constant and
# the `should_proceed` method below to `_ENTERPRISE_lakeraAI_Moderation`, and
# makes `async_moderation_hook` start with:
#
#     if await self.should_proceed(data=data) is False:
#         return

# Name a caller lists under metadata["guardrails"] to select this guardrail.
GUARDRAIL_NAME = "lakera_prompt_injection"


async def should_proceed(self, data: dict) -> bool:
    """Return whether this guardrail should be applied to the call.

    When the request carries ``data["metadata"]["guardrails"]``, that value is
    the collection of guardrails the user wants to run on the call, and this
    guardrail runs only if ``GUARDRAIL_NAME`` is a member of it. In every
    other case the guardrail proceeds.

    Args:
        data: The request payload. Only ``data["metadata"]["guardrails"]``
            is inspected.

    Returns:
        ``False`` only when a guardrail selection is present and does not
        contain ``GUARDRAIL_NAME``; ``True`` otherwise.
    """
    metadata = data.get("metadata")
    if not isinstance(metadata, dict) or "guardrails" not in metadata:
        # No selection supplied: in all such cases the guardrail proceeds.
        return True

    requested = metadata["guardrails"]
    if requested is None:
        # An explicit null is treated like "no selection" rather than
        # crashing the membership test below.
        return True

    try:
        # Membership uses `in`, so a list, tuple, set, dict (keys) or string
        # behaves exactly as it did before this guard was added.
        return GUARDRAIL_NAME in requested
    except TypeError:
        # Non-container values (e.g. a bool or int) cannot express a
        # selection; keep the guardrail on instead of raising.
        return True