forked from phoenix/litellm-mirror
feat(lakera_ai.py): control running prompt injection between pre-call and in_parallel
This commit is contained in:
parent
a32a7af215
commit
99a5436ed5
6 changed files with 211 additions and 37 deletions
|
@ -290,6 +290,7 @@ litellm_settings:
|
|||
- Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
|
||||
- `default_on`: bool, will run on all llm requests when true
|
||||
- `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
|
||||
- `callback_args`: Optional[Dict[str, Dict]]: If set, pass in init args for that specific guardrail
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -299,6 +300,7 @@ litellm_settings:
|
|||
- prompt_injection: # your custom name for guardrail
|
||||
callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
|
||||
default_on: true # will run on all llm requests when true
|
||||
callback_args: {"lakera_prompt_injection": {"moderation_check": "pre_call"}}
|
||||
- hide_secrets:
|
||||
callbacks: [hide_secrets]
|
||||
default_on: true
|
||||
|
|
|
@ -10,7 +10,7 @@ import sys, os
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from typing import Literal, List, Dict
|
||||
from typing import Literal, List, Dict, Optional, Union
|
||||
import litellm, sys
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
|
|||
|
||||
|
||||
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||
def __init__(self):
|
||||
def __init__(
|
||||
self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
|
||||
):
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
|
||||
self.moderation_check = moderation_check
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: litellm.DualCache,
|
||||
data: Dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
],
|
||||
) -> Optional[Union[Exception, str, Dict]]:
|
||||
if self.moderation_check == "in_parallel":
|
||||
return None
|
||||
|
||||
return await super().async_pre_call_hook(
|
||||
user_api_key_dict, cache, data, call_type
|
||||
)
|
||||
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
self,
|
||||
|
@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
|||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
if self.moderation_check == "pre_call":
|
||||
return
|
||||
|
||||
if (
|
||||
await should_proceed_based_on_metadata(
|
||||
|
|
|
@ -110,7 +110,12 @@ def initialize_callbacks_on_proxy(
|
|||
+ CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
|
||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
|
||||
init_params = {}
|
||||
if "lakera_prompt_injection" in callback_specific_params:
|
||||
init_params = callback_specific_params["lakera_prompt_injection"]
|
||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
|
||||
**init_params
|
||||
)
|
||||
imported_list.append(lakera_moderations_object)
|
||||
elif isinstance(callback, str) and callback == "aporio_prompt_injection":
|
||||
from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
|
||||
|
|
|
@ -38,6 +38,8 @@ def initialize_guardrails(
|
|||
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
||||
verbose_proxy_logger.debug(guardrail.default_on)
|
||||
|
||||
callback_specific_params.update(guardrail.callback_args)
|
||||
|
||||
if guardrail.default_on is True:
|
||||
# add these to litellm callbacks if they don't exist
|
||||
for callback in guardrail.callbacks:
|
||||
|
@ -46,7 +48,7 @@ def initialize_guardrails(
|
|||
|
||||
if guardrail.logging_only is True:
|
||||
if callback == "presidio":
|
||||
callback_specific_params["logging_only"] = True
|
||||
callback_specific_params["logging_only"] = True # type: ignore
|
||||
|
||||
default_on_callbacks_list = list(default_on_callbacks)
|
||||
if len(default_on_callbacks_list) > 0:
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
# What is this?
|
||||
## This tests the Lakera AI integration
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import HTTPException, Request, Response
|
||||
from fastapi.routing import APIRoute
|
||||
from starlette.datastructures import URL
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.types.guardrails import GuardrailItem
|
||||
|
||||
load_dotenv()
|
||||
|
@ -19,6 +19,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -31,12 +32,10 @@ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
|||
)
|
||||
from litellm.proxy.proxy_server import embeddings
|
||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
||||
from litellm.proxy.utils import hash_token
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def make_config_map(config: dict):
|
||||
m = {}
|
||||
for k, v in config.items():
|
||||
|
@ -44,7 +43,19 @@ def make_config_map(config: dict):
|
|||
m[k] = guardrail_item
|
||||
return m
|
||||
|
||||
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['system', 'user']}}))
|
||||
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection", "prompt_injection_api_2"],
|
||||
"default_on": True,
|
||||
"enabled_roles": ["system", "user"],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_lakera_prompt_injection_detection():
|
||||
"""
|
||||
|
@ -78,7 +89,17 @@ async def test_lakera_prompt_injection_detection():
|
|||
assert "Violated content safety policy" in str(http_exception)
|
||||
|
||||
|
||||
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_lakera_safe_prompt():
|
||||
"""
|
||||
|
@ -152,10 +173,21 @@ async def test_moderations_on_embeddings():
|
|||
print("got an exception", (str(e)))
|
||||
assert "Violated content safety policy" in str(e.message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||
@patch("litellm.guardrail_name_config_map",
|
||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True, "enabled_roles": ["user", "system"]}}))
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
new=make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
"enabled_roles": ["user", "system"],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
async def test_messages_for_disabled_role(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
data = {
|
||||
|
@ -172,66 +204,119 @@ async def test_messages_for_disabled_role(spy_post):
|
|||
{"role": "user", "content": "corgi sploot"},
|
||||
]
|
||||
}
|
||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
||||
await moderation.async_moderation_hook(
|
||||
data=data, user_api_key_dict=None, call_type="completion"
|
||||
)
|
||||
|
||||
_, kwargs = spy_post.call_args
|
||||
assert json.loads(kwargs.get('data')) == expected_data
|
||||
assert json.loads(kwargs.get("data")) == expected_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||
@patch("litellm.guardrail_name_config_map",
|
||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
new=make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
@patch("litellm.add_function_to_prompt", False)
|
||||
async def test_system_message_with_function_input(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "Initial content."},
|
||||
{"role": "user", "content": "Where are the best sunsets?", "tool_calls": [{"function": {"arguments": "Function args"}}]}
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Where are the best sunsets?",
|
||||
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
expected_data = {
|
||||
"input": [
|
||||
{"role": "system", "content": "Initial content. Function Input: Function args"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Initial content. Function Input: Function args",
|
||||
},
|
||||
{"role": "user", "content": "Where are the best sunsets?"},
|
||||
]
|
||||
}
|
||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
||||
await moderation.async_moderation_hook(
|
||||
data=data, user_api_key_dict=None, call_type="completion"
|
||||
)
|
||||
|
||||
_, kwargs = spy_post.call_args
|
||||
assert json.loads(kwargs.get('data')) == expected_data
|
||||
assert json.loads(kwargs.get("data")) == expected_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||
@patch("litellm.guardrail_name_config_map",
|
||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
new=make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
@patch("litellm.add_function_to_prompt", False)
|
||||
async def test_multi_message_with_function_input(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "Initial content.", "tool_calls": [{"function": {"arguments": "Function args"}}]},
|
||||
{"role": "user", "content": "Strawberry", "tool_calls": [{"function": {"arguments": "Function args"}}]}
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Initial content.",
|
||||
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Strawberry",
|
||||
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||
},
|
||||
]
|
||||
}
|
||||
expected_data = {
|
||||
"input": [
|
||||
{"role": "system", "content": "Initial content. Function Input: Function args Function args"},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Initial content. Function Input: Function args Function args",
|
||||
},
|
||||
{"role": "user", "content": "Strawberry"},
|
||||
]
|
||||
}
|
||||
|
||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
||||
await moderation.async_moderation_hook(
|
||||
data=data, user_api_key_dict=None, call_type="completion"
|
||||
)
|
||||
|
||||
_, kwargs = spy_post.call_args
|
||||
assert json.loads(kwargs.get('data')) == expected_data
|
||||
assert json.loads(kwargs.get("data")) == expected_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||
@patch("litellm.guardrail_name_config_map",
|
||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
||||
@patch(
|
||||
"litellm.guardrail_name_config_map",
|
||||
new=make_config_map(
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
async def test_message_ordering(spy_post):
|
||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||
data = {
|
||||
|
@ -249,8 +334,57 @@ async def test_message_ordering(spy_post):
|
|||
]
|
||||
}
|
||||
|
||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
||||
await moderation.async_moderation_hook(
|
||||
data=data, user_api_key_dict=None, call_type="completion"
|
||||
)
|
||||
|
||||
_, kwargs = spy_post.call_args
|
||||
assert json.loads(kwargs.get('data')) == expected_data
|
||||
assert json.loads(kwargs.get("data")) == expected_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_specific_param_run_pre_call_check_lakera():
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||
|
||||
os.environ["LAKERA_API_KEY"] = "7a91a1a6059da*******"
|
||||
|
||||
guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
|
||||
{
|
||||
"prompt_injection": {
|
||||
"callbacks": ["lakera_prompt_injection"],
|
||||
"default_on": True,
|
||||
"callback_args": {
|
||||
"lakera_prompt_injection": {"moderation_check": "pre_call"}
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
litellm_settings = {"guardrails": guardrails_config}
|
||||
|
||||
assert len(litellm.guardrail_name_config_map) == 0
|
||||
initialize_guardrails(
|
||||
guardrails_config=guardrails_config,
|
||||
premium_user=True,
|
||||
config_file_path="",
|
||||
litellm_settings=litellm_settings,
|
||||
)
|
||||
|
||||
assert len(litellm.guardrail_name_config_map) == 1
|
||||
|
||||
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
||||
print("litellm callbacks={}".format(litellm.callbacks))
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
||||
prompt_injection_obj = callback
|
||||
else:
|
||||
print("Type of callback={}".format(type(callback)))
|
||||
|
||||
assert prompt_injection_obj is not None
|
||||
|
||||
assert hasattr(prompt_injection_obj, "moderation_check")
|
||||
assert prompt_injection_obj.moderation_check == "pre_call"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
@ -33,6 +33,7 @@ class GuardrailItemSpec(TypedDict, total=False):
|
|||
default_on: bool
|
||||
logging_only: Optional[bool]
|
||||
enabled_roles: Optional[List[Role]]
|
||||
callback_args: Dict[str, Dict]
|
||||
|
||||
|
||||
class GuardrailItem(BaseModel):
|
||||
|
@ -40,7 +41,9 @@ class GuardrailItem(BaseModel):
|
|||
default_on: bool
|
||||
logging_only: Optional[bool]
|
||||
guardrail_name: str
|
||||
callback_args: Dict[str, Dict]
|
||||
enabled_roles: Optional[List[Role]]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
def __init__(
|
||||
|
@ -50,6 +53,7 @@ class GuardrailItem(BaseModel):
|
|||
default_on: bool = False,
|
||||
logging_only: Optional[bool] = None,
|
||||
enabled_roles: Optional[List[Role]] = default_roles,
|
||||
callback_args: Dict[str, Dict] = {},
|
||||
):
|
||||
super().__init__(
|
||||
callbacks=callbacks,
|
||||
|
@ -57,4 +61,5 @@ class GuardrailItem(BaseModel):
|
|||
logging_only=logging_only,
|
||||
guardrail_name=guardrail_name,
|
||||
enabled_roles=enabled_roles,
|
||||
callback_args=callback_args,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue