fix(lakera_ai.py): fix hardcoded prompt_injection string in lakera integration

This commit is contained in:
Krrish Dholakia 2024-08-06 16:12:54 -07:00
parent 2a95484a83
commit 907d83e5d9
2 changed files with 19 additions and 9 deletions

View file

@ -10,7 +10,7 @@ import sys, os
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from typing import Literal, List, Dict from typing import Literal, List, Dict, Optional
import litellm, sys import litellm, sys
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -64,11 +64,24 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
return return
text = "" text = ""
if "messages" in data and isinstance(data["messages"], list): if "messages" in data and isinstance(data["messages"], list):
enabled_roles = litellm.guardrail_name_config_map[ prompt_injection_obj: Optional[GuardrailItem] = (
"prompt_injection" litellm.guardrail_name_config_map.get("prompt_injection")
].enabled_roles )
if prompt_injection_obj is not None:
enabled_roles = prompt_injection_obj.enabled_roles
else:
enabled_roles = None
if enabled_roles is None: if enabled_roles is None:
enabled_roles = default_roles enabled_roles = default_roles
stringified_roles: List[str] = []
if enabled_roles is not None: # convert to list of str
for role in enabled_roles:
if isinstance(role, Role):
stringified_roles.append(role.value)
elif isinstance(role, str):
stringified_roles.append(role)
lakera_input_dict: Dict = { lakera_input_dict: Dict = {
role: None for role in INPUT_POSITIONING_MAP.keys() role: None for role in INPUT_POSITIONING_MAP.keys()
} }
@ -76,7 +89,7 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
tool_call_messages: List = [] tool_call_messages: List = []
for message in data["messages"]: for message in data["messages"]:
role = message.get("role") role = message.get("role")
if role in enabled_roles: if role in stringified_roles:
if "tool_calls" in message: if "tool_calls" in message:
tool_call_messages = [ tool_call_messages = [
*tool_call_messages, *tool_call_messages,

View file

@ -4,7 +4,4 @@ model_list:
model: "*" model: "*"
litellm_settings: litellm_settings:
guardrails: callbacks: ["lakera_prompt_injection"]
- prompt_injection: # your custom name for guardrail
callbacks: [lakera_prompt_injection] # litellm callbacks to use
default_on: true # will run on all llm requests when true