fix linting errors on main

This commit is contained in:
Ishaan Jaff 2024-07-18 13:32:48 -07:00
parent df4aab8be9
commit 75ca53fab5
3 changed files with 64 additions and 29 deletions

View file

@ -10,7 +10,7 @@ import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Literal
from typing import Literal, List, Dict
import litellm, sys
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -18,7 +18,7 @@ from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import Role
from litellm.types.guardrails import Role, GuardrailItem, default_roles
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@ -33,9 +33,10 @@ GUARDRAIL_NAME = "lakera_prompt_injection"
INPUT_POSITIONING_MAP = {
Role.SYSTEM.value: 0,
Role.USER.value: 1,
Role.ASSISTANT.value: 2
Role.ASSISTANT.value: 2,
}
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
def __init__(self):
self.async_handler = AsyncHTTPHandler(
@ -63,41 +64,63 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
return
text = ""
if "messages" in data and isinstance(data["messages"], list):
enabled_roles = litellm.guardrail_name_config_map["prompt_injection"].enabled_roles
lakera_input_dict = {role: None for role in INPUT_POSITIONING_MAP.keys()}
system_message = None
tool_call_messages = []
enabled_roles = litellm.guardrail_name_config_map[
"prompt_injection"
].enabled_roles
if enabled_roles is None:
enabled_roles = default_roles
lakera_input_dict: Dict = {
role: None for role in INPUT_POSITIONING_MAP.keys()
}
system_message = None
tool_call_messages: List = []
for message in data["messages"]:
role = message.get("role")
if role in enabled_roles:
if "tool_calls" in message:
tool_call_messages = [*tool_call_messages, *message["tool_calls"]]
if role == Role.SYSTEM.value: # we need this for later
tool_call_messages = [
*tool_call_messages,
*message["tool_calls"],
]
if role == Role.SYSTEM.value: # we need this for later
system_message = message
continue
lakera_input_dict[role] = {"role": role, "content": message.get('content')}
lakera_input_dict[role] = {
"role": role,
"content": message.get("content"),
}
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
# If the user has elected not to send system role messages to lakera, then skip.
if system_message is not None:
if not litellm.add_function_to_prompt:
content = system_message.get("content")
function_input = []
function_input = []
for tool_call in tool_call_messages:
if "function" in tool_call:
function_input.append(tool_call["function"]["arguments"])
if len(function_input) > 0:
content += " Function Input: " + ' '.join(function_input)
lakera_input_dict[Role.SYSTEM.value] = {'role': Role.SYSTEM.value, 'content': content}
content += " Function Input: " + " ".join(function_input)
lakera_input_dict[Role.SYSTEM.value] = {
"role": Role.SYSTEM.value,
"content": content,
}
lakera_input = [v for k, v in sorted(lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]) if v is not None]
lakera_input = [
v
for k, v in sorted(
lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]
)
if v is not None
]
if len(lakera_input) == 0:
verbose_proxy_logger.debug("Skipping lakera prompt injection, no roles with messages found")
verbose_proxy_logger.debug(
"Skipping lakera prompt injection, no roles with messages found"
)
return
elif "input" in data and isinstance(data["input"], str):