feat(lakera_ai.py): control running prompt injection between pre-call and in_parallel

Krrish Dholakia 2024-07-22 20:04:42 -07:00
parent a32a7af215
commit 99a5436ed5
6 changed files with 211 additions and 37 deletions

View file

@@ -290,6 +290,7 @@ litellm_settings:
   - Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
   - `default_on`: bool, will run on all llm requests when true
   - `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
+  - `callback_args`: Optional[Dict[str, Dict]]: If set, pass in init args for that specific guardrail

 Example:

@@ -299,6 +300,7 @@ litellm_settings:
     - prompt_injection:  # your custom name for guardrail
         callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
         default_on: true # will run on all llm requests when true
+        callback_args: {"lakera_prompt_injection": {"moderation_check": "pre_call"}}
     - hide_secrets:
         callbacks: [hide_secrets]
         default_on: true

View file

@@ -10,7 +10,7 @@ import sys, os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from typing import Literal, List, Dict
+from typing import Literal, List, Dict, Optional, Union
 import litellm, sys
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
 class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
-    def __init__(self):
+    def __init__(
+        self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
+    ):
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
         self.lakera_api_key = os.environ["LAKERA_API_KEY"]
+        self.moderation_check = moderation_check
         pass

     #### CALL HOOKS - proxy only ####

+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: litellm.DualCache,
+        data: Dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+        ],
+    ) -> Optional[Union[Exception, str, Dict]]:
+        if self.moderation_check == "in_parallel":
+            return None
+        return await super().async_pre_call_hook(
+            user_api_key_dict, cache, data, call_type
+        )
+
     async def async_moderation_hook(  ### 👈 KEY CHANGE ###
         self,

@@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
         user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
+        if self.moderation_check == "pre_call":
+            return
         if (
             await should_proceed_based_on_metadata(
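
Taken together, the two guards above mean exactly one of the hooks performs the Lakera scan, selected by moderation_check. A minimal, self-contained sketch of that dispatch pattern (a simplified stand-in class with print statements, not the actual litellm implementation):

import asyncio
from typing import Literal, Optional


class LakeraCheckSketch:
    """Simplified stand-in for the guardrail class above (illustration only)."""

    def __init__(
        self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
    ):
        self.moderation_check = moderation_check

    async def async_pre_call_hook(self, data: dict) -> Optional[dict]:
        # Runs before the LLM call; only scan here in "pre_call" mode.
        if self.moderation_check == "in_parallel":
            return None
        print("scanning request in pre_call (blocks the LLM call until done)")
        return data

    async def async_moderation_hook(self, data: dict) -> None:
        # Runs alongside the LLM call; only scan here in "in_parallel" mode.
        if self.moderation_check == "pre_call":
            return
        print("scanning request in_parallel (does not delay the LLM call)")


async def main() -> None:
    request = {"messages": [{"role": "user", "content": "hi"}]}
    pre = LakeraCheckSketch(moderation_check="pre_call")
    par = LakeraCheckSketch()  # defaults to "in_parallel"
    await pre.async_pre_call_hook(request)    # scans
    await pre.async_moderation_hook(request)  # no-op
    await par.async_pre_call_hook(request)    # no-op
    await par.async_moderation_hook(request)  # scans


asyncio.run(main())

Keeping both hooks on the class and short-circuiting one of them lets a single config value decide whether the scan blocks the request or runs alongside it.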

View file

@@ -110,7 +110,12 @@ def initialize_callbacks_on_proxy(
                     + CommonProxyErrors.not_premium_user.value
                 )

-            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
+            init_params = {}
+            if "lakera_prompt_injection" in callback_specific_params:
+                init_params = callback_specific_params["lakera_prompt_injection"]
+            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
+                **init_params
+            )
             imported_list.append(lakera_moderations_object)
         elif isinstance(callback, str) and callback == "aporio_prompt_injection":
             from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
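
The init_params branch above simply forwards whatever was collected under this callback's name as constructor keyword arguments. A small stand-alone sketch of the same pattern (the dict contents are hypothetical; in the proxy they come from the guardrail's callback_args):

from typing import Any, Dict

# Hypothetical contents of callback_specific_params after config parsing.
callback_specific_params: Dict[str, Dict[str, Any]] = {
    "lakera_prompt_injection": {"moderation_check": "pre_call"},
}

init_params: Dict[str, Any] = {}
if "lakera_prompt_injection" in callback_specific_params:
    init_params = callback_specific_params["lakera_prompt_injection"]


def lakera_moderation_stub(moderation_check: str = "in_parallel") -> str:
    """Stands in for the guardrail constructor called with **init_params."""
    return moderation_check


# Any key in callback_args must match a keyword the guardrail's __init__
# accepts, otherwise the **init_params unpacking raises TypeError at startup.
print(lakera_moderation_stub(**init_params))  # -> "pre_call"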

View file

@@ -38,6 +38,8 @@ def initialize_guardrails(
             verbose_proxy_logger.debug(guardrail.guardrail_name)
             verbose_proxy_logger.debug(guardrail.default_on)

+            callback_specific_params.update(guardrail.callback_args)
+
             if guardrail.default_on is True:
                 # add these to litellm callbacks if they don't exist
                 for callback in guardrail.callbacks:

@@ -46,7 +48,7 @@ def initialize_guardrails(
                     if guardrail.logging_only is True:
                         if callback == "presidio":
-                            callback_specific_params["logging_only"] = True
+                            callback_specific_params["logging_only"] = True  # type: ignore

         default_on_callbacks_list = list(default_on_callbacks)
         if len(default_on_callbacks_list) > 0:
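
The update() call above folds each guardrail's callback_args into the shared callback_specific_params dict, keyed by callback name, so a later guardrail entry overwrites an earlier one for the same callback. A short sketch of that merge behaviour (the two entries are illustrative, not taken from the commit):

from typing import Any, Dict

# Illustrative callback_args from two guardrail entries in a config.
guardrail_callback_args = [
    {"lakera_prompt_injection": {"moderation_check": "pre_call"}},
    {"lakera_prompt_injection": {"moderation_check": "in_parallel"}},
]

callback_specific_params: Dict[str, Dict[str, Any]] = {}
for callback_args in guardrail_callback_args:
    callback_specific_params.update(callback_args)

# The second entry wins because dict.update replaces existing keys.
print(callback_specific_params)
# {'lakera_prompt_injection': {'moderation_check': 'in_parallel'}}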

View file

@@ -1,15 +1,15 @@
 # What is this?
 ## This tests the Lakera AI integration
+import json
 import os
 import sys
-import json

 from dotenv import load_dotenv
 from fastapi import HTTPException, Request, Response
 from fastapi.routing import APIRoute
 from starlette.datastructures import URL
-from fastapi import HTTPException
 from litellm.types.guardrails import GuardrailItem

 load_dotenv()

@@ -19,6 +19,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import logging
+from unittest.mock import patch

 import pytest

@@ -31,12 +32,10 @@ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
 )
 from litellm.proxy.proxy_server import embeddings
 from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy.utils import hash_token
-from unittest.mock import patch

 verbose_proxy_logger.setLevel(logging.DEBUG)


 def make_config_map(config: dict):
     m = {}
     for k, v in config.items():
@@ -44,7 +43,19 @@ def make_config_map(config: dict):
         m[k] = guardrail_item
     return m

-@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['system', 'user']}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection", "prompt_injection_api_2"],
+                "default_on": True,
+                "enabled_roles": ["system", "user"],
+            }
+        }
+    ),
+)
 @pytest.mark.asyncio
 async def test_lakera_prompt_injection_detection():
     """
@@ -78,7 +89,17 @@ async def test_lakera_prompt_injection_detection():
         assert "Violated content safety policy" in str(http_exception)


-@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @pytest.mark.asyncio
 async def test_lakera_safe_prompt():
     """
@@ -152,17 +173,28 @@ async def test_moderations_on_embeddings():
         print("got an exception", (str(e)))
         assert "Violated content safety policy" in str(e.message)


 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True, "enabled_roles": ["user", "system"]}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+                "enabled_roles": ["user", "system"],
+            }
+        }
+    ),
+)
 async def test_messages_for_disabled_role(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "assistant", "content": "This should be ignored." },
+            {"role": "assistant", "content": "This should be ignored."},
             {"role": "user", "content": "corgi sploot"},
-            {"role": "system", "content": "Initial content." },
+            {"role": "system", "content": "Initial content."},
         ]
     }
@@ -172,66 +204,119 @@ async def test_messages_for_disabled_role(spy_post):
             {"role": "user", "content": "corgi sploot"},
         ]
     }
-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )

     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data


 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @patch("litellm.add_function_to_prompt", False)
 async def test_system_message_with_function_input(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "system", "content": "Initial content." },
-            {"role": "user", "content": "Where are the best sunsets?", "tool_calls": [{"function": {"arguments": "Function args"}}]}
+            {"role": "system", "content": "Initial content."},
+            {
+                "role": "user",
+                "content": "Where are the best sunsets?",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
         ]
     }

     expected_data = {
         "input": [
-            {"role": "system", "content": "Initial content. Function Input: Function args"},
+            {
+                "role": "system",
+                "content": "Initial content. Function Input: Function args",
+            },
             {"role": "user", "content": "Where are the best sunsets?"},
         ]
     }

-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )

     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data


 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @patch("litellm.add_function_to_prompt", False)
 async def test_multi_message_with_function_input(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "system", "content": "Initial content.", "tool_calls": [{"function": {"arguments": "Function args"}}]},
-            {"role": "user", "content": "Strawberry", "tool_calls": [{"function": {"arguments": "Function args"}}]}
+            {
+                "role": "system",
+                "content": "Initial content.",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
+            {
+                "role": "user",
+                "content": "Strawberry",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
         ]
     }

     expected_data = {
         "input": [
-            {"role": "system", "content": "Initial content. Function Input: Function args Function args"},
+            {
+                "role": "system",
+                "content": "Initial content. Function Input: Function args Function args",
+            },
             {"role": "user", "content": "Strawberry"},
         ]
     }

-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )

     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data


 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 async def test_message_ordering(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
@@ -249,8 +334,57 @@ async def test_message_ordering(spy_post):
         ]
     }

-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )

     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data
+
+
+@pytest.mark.asyncio
+async def test_callback_specific_param_run_pre_call_check_lakera():
+    from typing import Dict, List, Optional, Union
+
+    import litellm
+    from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
+    from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+    from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+    os.environ["LAKERA_API_KEY"] = "7a91a1a6059da*******"
+
+    guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+                "callback_args": {
+                    "lakera_prompt_injection": {"moderation_check": "pre_call"}
+                },
+            }
+        }
+    ]
+    litellm_settings = {"guardrails": guardrails_config}
+
+    assert len(litellm.guardrail_name_config_map) == 0
+    initialize_guardrails(
+        guardrails_config=guardrails_config,
+        premium_user=True,
+        config_file_path="",
+        litellm_settings=litellm_settings,
+    )
+
+    assert len(litellm.guardrail_name_config_map) == 1
+
+    prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
+    print("litellm callbacks={}".format(litellm.callbacks))
+    for callback in litellm.callbacks:
+        if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
+            prompt_injection_obj = callback
+        else:
+            print("Type of callback={}".format(type(callback)))
+
+    assert prompt_injection_obj is not None
+    assert hasattr(prompt_injection_obj, "moderation_check")
+    assert prompt_injection_obj.moderation_check == "pre_call"

View file

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Optional
+from typing import Dict, List, Optional

 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Required, TypedDict

@@ -33,6 +33,7 @@ class GuardrailItemSpec(TypedDict, total=False):
     default_on: bool
     logging_only: Optional[bool]
     enabled_roles: Optional[List[Role]]
+    callback_args: Dict[str, Dict]


 class GuardrailItem(BaseModel):

@@ -40,7 +41,9 @@ class GuardrailItem(BaseModel):
     default_on: bool
     logging_only: Optional[bool]
     guardrail_name: str
+    callback_args: Dict[str, Dict]
     enabled_roles: Optional[List[Role]]
+
     model_config = ConfigDict(use_enum_values=True)

     def __init__(

@@ -50,6 +53,7 @@ class GuardrailItem(BaseModel):
         default_on: bool = False,
         logging_only: Optional[bool] = None,
         enabled_roles: Optional[List[Role]] = default_roles,
+        callback_args: Dict[str, Dict] = {},
     ):
         super().__init__(
             callbacks=callbacks,

@@ -57,4 +61,5 @@ class GuardrailItem(BaseModel):
             logging_only=logging_only,
             guardrail_name=guardrail_name,
             enabled_roles=enabled_roles,
+            callback_args=callback_args,
         )
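
A usage sketch for the extended model (assumes litellm is installed, that guardrail_name and callbacks remain the other required fields, and that callback_args defaults to an empty dict when omitted):

from litellm.types.guardrails import GuardrailItem

item = GuardrailItem(
    guardrail_name="prompt_injection",
    callbacks=["lakera_prompt_injection"],
    default_on=True,
    callback_args={"lakera_prompt_injection": {"moderation_check": "pre_call"}},
)

# The nested dict is later handed to the matching callback's constructor.
print(item.callback_args["lakera_prompt_injection"]["moderation_check"])  # pre_call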