feat(lakera_ai.py): control running prompt injection between pre-call and in_parallel

This commit is contained in:
Krrish Dholakia 2024-07-22 20:04:42 -07:00
parent a32a7af215
commit 99a5436ed5
6 changed files with 211 additions and 37 deletions

View file

@ -290,6 +290,7 @@ litellm_settings:
- Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
- `default_on`: bool, will run on all llm requests when true
- `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
- `callback_args`: Optional[Dict[str, Dict]]: If set, pass in init args for that specific guardrail
Example:
@ -299,6 +300,7 @@ litellm_settings:
- prompt_injection: # your custom name for guardrail
callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
default_on: true # will run on all llm requests when true
callback_args: {"lakera_prompt_injection": {"moderation_check": "pre_call"}}
- hide_secrets:
callbacks: [hide_secrets]
default_on: true

View file

@ -10,7 +10,7 @@ import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Literal, List, Dict
from typing import Literal, List, Dict, Optional, Union
import litellm, sys
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
def __init__(self):
def __init__(
self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
self.moderation_check = moderation_check
pass
#### CALL HOOKS - proxy only ####
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: litellm.DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
],
) -> Optional[Union[Exception, str, Dict]]:
if self.moderation_check == "in_parallel":
return None
return await super().async_pre_call_hook(
user_api_key_dict, cache, data, call_type
)
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
if self.moderation_check == "pre_call":
return
if (
await should_proceed_based_on_metadata(

View file

@ -110,7 +110,12 @@ def initialize_callbacks_on_proxy(
+ CommonProxyErrors.not_premium_user.value
)
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
init_params = {}
if "lakera_prompt_injection" in callback_specific_params:
init_params = callback_specific_params["lakera_prompt_injection"]
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
**init_params
)
imported_list.append(lakera_moderations_object)
elif isinstance(callback, str) and callback == "aporio_prompt_injection":
from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio

View file

@ -38,6 +38,8 @@ def initialize_guardrails(
verbose_proxy_logger.debug(guardrail.guardrail_name)
verbose_proxy_logger.debug(guardrail.default_on)
callback_specific_params.update(guardrail.callback_args)
if guardrail.default_on is True:
# add these to litellm callbacks if they don't exist
for callback in guardrail.callbacks:
@ -46,7 +48,7 @@ def initialize_guardrails(
if guardrail.logging_only is True:
if callback == "presidio":
callback_specific_params["logging_only"] = True
callback_specific_params["logging_only"] = True # type: ignore
default_on_callbacks_list = list(default_on_callbacks)
if len(default_on_callbacks_list) > 0:

View file

@ -1,15 +1,15 @@
# What is this?
## This tests the Lakera AI integration
import json
import os
import sys
import json
from dotenv import load_dotenv
from fastapi import HTTPException, Request, Response
from fastapi.routing import APIRoute
from starlette.datastructures import URL
from fastapi import HTTPException
from litellm.types.guardrails import GuardrailItem
load_dotenv()
@ -19,6 +19,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import logging
from unittest.mock import patch
import pytest
@ -31,12 +32,10 @@ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
)
from litellm.proxy.proxy_server import embeddings
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy.utils import hash_token
from unittest.mock import patch
verbose_proxy_logger.setLevel(logging.DEBUG)
def make_config_map(config: dict):
m = {}
for k, v in config.items():
@ -44,7 +43,19 @@ def make_config_map(config: dict):
m[k] = guardrail_item
return m
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['system', 'user']}}))
@patch(
"litellm.guardrail_name_config_map",
make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection", "prompt_injection_api_2"],
"default_on": True,
"enabled_roles": ["system", "user"],
}
}
),
)
@pytest.mark.asyncio
async def test_lakera_prompt_injection_detection():
"""
@ -78,7 +89,17 @@ async def test_lakera_prompt_injection_detection():
assert "Violated content safety policy" in str(http_exception)
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
@patch(
"litellm.guardrail_name_config_map",
make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
}
}
),
)
@pytest.mark.asyncio
async def test_lakera_safe_prompt():
"""
@ -152,10 +173,21 @@ async def test_moderations_on_embeddings():
print("got an exception", (str(e)))
assert "Violated content safety policy" in str(e.message)
@pytest.mark.asyncio
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
@patch("litellm.guardrail_name_config_map",
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True, "enabled_roles": ["user", "system"]}}))
@patch(
"litellm.guardrail_name_config_map",
new=make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
"enabled_roles": ["user", "system"],
}
}
),
)
async def test_messages_for_disabled_role(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
data = {
@ -172,66 +204,119 @@ async def test_messages_for_disabled_role(spy_post):
{"role": "user", "content": "corgi sploot"},
]
}
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
await moderation.async_moderation_hook(
data=data, user_api_key_dict=None, call_type="completion"
)
_, kwargs = spy_post.call_args
assert json.loads(kwargs.get('data')) == expected_data
assert json.loads(kwargs.get("data")) == expected_data
@pytest.mark.asyncio
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
@patch("litellm.guardrail_name_config_map",
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
@patch(
"litellm.guardrail_name_config_map",
new=make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
}
}
),
)
@patch("litellm.add_function_to_prompt", False)
async def test_system_message_with_function_input(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
data = {
"messages": [
{"role": "system", "content": "Initial content."},
{"role": "user", "content": "Where are the best sunsets?", "tool_calls": [{"function": {"arguments": "Function args"}}]}
{
"role": "user",
"content": "Where are the best sunsets?",
"tool_calls": [{"function": {"arguments": "Function args"}}],
},
]
}
expected_data = {
"input": [
{"role": "system", "content": "Initial content. Function Input: Function args"},
{
"role": "system",
"content": "Initial content. Function Input: Function args",
},
{"role": "user", "content": "Where are the best sunsets?"},
]
}
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
await moderation.async_moderation_hook(
data=data, user_api_key_dict=None, call_type="completion"
)
_, kwargs = spy_post.call_args
assert json.loads(kwargs.get('data')) == expected_data
assert json.loads(kwargs.get("data")) == expected_data
@pytest.mark.asyncio
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
@patch("litellm.guardrail_name_config_map",
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
@patch(
"litellm.guardrail_name_config_map",
new=make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
}
}
),
)
@patch("litellm.add_function_to_prompt", False)
async def test_multi_message_with_function_input(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
data = {
"messages": [
{"role": "system", "content": "Initial content.", "tool_calls": [{"function": {"arguments": "Function args"}}]},
{"role": "user", "content": "Strawberry", "tool_calls": [{"function": {"arguments": "Function args"}}]}
{
"role": "system",
"content": "Initial content.",
"tool_calls": [{"function": {"arguments": "Function args"}}],
},
{
"role": "user",
"content": "Strawberry",
"tool_calls": [{"function": {"arguments": "Function args"}}],
},
]
}
expected_data = {
"input": [
{"role": "system", "content": "Initial content. Function Input: Function args Function args"},
{
"role": "system",
"content": "Initial content. Function Input: Function args Function args",
},
{"role": "user", "content": "Strawberry"},
]
}
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
await moderation.async_moderation_hook(
data=data, user_api_key_dict=None, call_type="completion"
)
_, kwargs = spy_post.call_args
assert json.loads(kwargs.get('data')) == expected_data
assert json.loads(kwargs.get("data")) == expected_data
@pytest.mark.asyncio
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
@patch("litellm.guardrail_name_config_map",
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
@patch(
"litellm.guardrail_name_config_map",
new=make_config_map(
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
}
}
),
)
async def test_message_ordering(spy_post):
moderation = _ENTERPRISE_lakeraAI_Moderation()
data = {
@ -249,8 +334,57 @@ async def test_message_ordering(spy_post):
]
}
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
await moderation.async_moderation_hook(
data=data, user_api_key_dict=None, call_type="completion"
)
_, kwargs = spy_post.call_args
assert json.loads(kwargs.get('data')) == expected_data
assert json.loads(kwargs.get("data")) == expected_data
@pytest.mark.asyncio
async def test_callback_specific_param_run_pre_call_check_lakera():
from typing import Dict, List, Optional, Union
import litellm
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
os.environ["LAKERA_API_KEY"] = "7a91a1a6059da*******"
guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
{
"prompt_injection": {
"callbacks": ["lakera_prompt_injection"],
"default_on": True,
"callback_args": {
"lakera_prompt_injection": {"moderation_check": "pre_call"}
},
}
}
]
litellm_settings = {"guardrails": guardrails_config}
assert len(litellm.guardrail_name_config_map) == 0
initialize_guardrails(
guardrails_config=guardrails_config,
premium_user=True,
config_file_path="",
litellm_settings=litellm_settings,
)
assert len(litellm.guardrail_name_config_map) == 1
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
print("litellm callbacks={}".format(litellm.callbacks))
for callback in litellm.callbacks:
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
prompt_injection_obj = callback
else:
print("Type of callback={}".format(type(callback)))
assert prompt_injection_obj is not None
assert hasattr(prompt_injection_obj, "moderation_check")
assert prompt_injection_obj.moderation_check == "pre_call"

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@ -33,6 +33,7 @@ class GuardrailItemSpec(TypedDict, total=False):
default_on: bool
logging_only: Optional[bool]
enabled_roles: Optional[List[Role]]
callback_args: Dict[str, Dict]
class GuardrailItem(BaseModel):
@ -40,7 +41,9 @@ class GuardrailItem(BaseModel):
default_on: bool
logging_only: Optional[bool]
guardrail_name: str
callback_args: Dict[str, Dict]
enabled_roles: Optional[List[Role]]
model_config = ConfigDict(use_enum_values=True)
def __init__(
@ -50,6 +53,7 @@ class GuardrailItem(BaseModel):
default_on: bool = False,
logging_only: Optional[bool] = None,
enabled_roles: Optional[List[Role]] = default_roles,
callback_args: Dict[str, Dict] = {},
):
super().__init__(
callbacks=callbacks,
@ -57,4 +61,5 @@ class GuardrailItem(BaseModel):
logging_only=logging_only,
guardrail_name=guardrail_name,
enabled_roles=enabled_roles,
callback_args=callback_args,
)