forked from phoenix/litellm-mirror
feat(lakera_ai.py): control whether the prompt injection check runs as `pre_call` or `in_parallel`
This commit is contained in:
parent
a32a7af215
commit
99a5436ed5
6 changed files with 211 additions and 37 deletions
|
@ -290,6 +290,7 @@ litellm_settings:
|
||||||
- Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
|
- Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
|
||||||
- `default_on`: bool, will run on all llm requests when true
|
- `default_on`: bool, will run on all llm requests when true
|
||||||
- `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
|
- `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
|
||||||
|
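- `callback_args`: Optional[Dict[str, Dict]], maps a callback name to the init args passed to that specific callback when it is created, e.g. `{"lakera_prompt_injection": {"moderation_check": "pre_call"}}`. For `lakera_prompt_injection`, `moderation_check` accepts `pre_call` or `in_parallel` (default: `in_parallel`).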
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
@ -299,6 +300,7 @@ litellm_settings:
|
||||||
- prompt_injection: # your custom name for guardrail
|
- prompt_injection: # your custom name for guardrail
|
||||||
callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
|
callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
|
||||||
default_on: true # will run on all llm requests when true
|
default_on: true # will run on all llm requests when true
|
||||||
|
callback_args: {"lakera_prompt_injection": {"moderation_check": "pre_call"}} # init args for the named callback, keyed by callback name
|
||||||
- hide_secrets:
|
- hide_secrets:
|
||||||
callbacks: [hide_secrets]
|
callbacks: [hide_secrets]
|
||||||
default_on: true
|
default_on: true
|
||||||
|
|
|
@ -10,7 +10,7 @@ import sys, os
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
from typing import Literal, List, Dict
|
from typing import Literal, List, Dict, Optional, Union
|
||||||
import litellm, sys
|
import litellm, sys
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
|
||||||
|
|
||||||
|
|
||||||
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
|
||||||
|
):
|
||||||
self.async_handler = AsyncHTTPHandler(
|
self.async_handler = AsyncHTTPHandler(
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
)
|
)
|
||||||
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
|
self.lakera_api_key = os.environ["LAKERA_API_KEY"]
|
||||||
|
self.moderation_check = moderation_check
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#### CALL HOOKS - proxy only ####
|
#### CALL HOOKS - proxy only ####
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: litellm.DualCache,
|
||||||
|
data: Dict,
|
||||||
|
call_type: Literal[
|
||||||
|
"completion",
|
||||||
|
"text_completion",
|
||||||
|
"embeddings",
|
||||||
|
"image_generation",
|
||||||
|
"moderation",
|
||||||
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
|
],
|
||||||
|
) -> Optional[Union[Exception, str, Dict]]:
|
||||||
|
if self.moderation_check == "in_parallel":
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await super().async_pre_call_hook(
|
||||||
|
user_api_key_dict, cache, data, call_type
|
||||||
|
)
|
||||||
|
|
||||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||||
self,
|
self,
|
||||||
|
@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
):
|
):
|
||||||
|
if self.moderation_check == "pre_call":
|
||||||
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
await should_proceed_based_on_metadata(
|
await should_proceed_based_on_metadata(
|
||||||
|
|
|
@ -110,7 +110,12 @@ def initialize_callbacks_on_proxy(
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
)
|
)
|
||||||
|
|
||||||
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
|
init_params = {}
|
||||||
|
if "lakera_prompt_injection" in callback_specific_params:
|
||||||
|
init_params = callback_specific_params["lakera_prompt_injection"]
|
||||||
|
lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
|
||||||
|
**init_params
|
||||||
|
)
|
||||||
imported_list.append(lakera_moderations_object)
|
imported_list.append(lakera_moderations_object)
|
||||||
elif isinstance(callback, str) and callback == "aporio_prompt_injection":
|
elif isinstance(callback, str) and callback == "aporio_prompt_injection":
|
||||||
from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
|
from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
|
||||||
|
|
|
@ -38,6 +38,8 @@ def initialize_guardrails(
|
||||||
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
||||||
verbose_proxy_logger.debug(guardrail.default_on)
|
verbose_proxy_logger.debug(guardrail.default_on)
|
||||||
|
|
||||||
|
callback_specific_params.update(guardrail.callback_args)
|
||||||
|
|
||||||
if guardrail.default_on is True:
|
if guardrail.default_on is True:
|
||||||
# add these to litellm callbacks if they don't exist
|
# add these to litellm callbacks if they don't exist
|
||||||
for callback in guardrail.callbacks:
|
for callback in guardrail.callbacks:
|
||||||
|
@ -46,7 +48,7 @@ def initialize_guardrails(
|
||||||
|
|
||||||
if guardrail.logging_only is True:
|
if guardrail.logging_only is True:
|
||||||
if callback == "presidio":
|
if callback == "presidio":
|
||||||
callback_specific_params["logging_only"] = True
|
callback_specific_params["logging_only"] = True # type: ignore
|
||||||
|
|
||||||
default_on_callbacks_list = list(default_on_callbacks)
|
default_on_callbacks_list = list(default_on_callbacks)
|
||||||
if len(default_on_callbacks_list) > 0:
|
if len(default_on_callbacks_list) > 0:
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## This tests the Lakera AI integration
|
## This tests the Lakera AI integration
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import HTTPException, Request, Response
|
from fastapi import HTTPException, Request, Response
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
from fastapi import HTTPException
|
|
||||||
from litellm.types.guardrails import GuardrailItem
|
from litellm.types.guardrails import GuardrailItem
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -19,6 +19,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -31,12 +32,10 @@ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.proxy_server import embeddings
|
from litellm.proxy.proxy_server import embeddings
|
||||||
from litellm.proxy.utils import ProxyLogging, hash_token
|
from litellm.proxy.utils import ProxyLogging, hash_token
|
||||||
from litellm.proxy.utils import hash_token
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
|
|
||||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def make_config_map(config: dict):
|
def make_config_map(config: dict):
|
||||||
m = {}
|
m = {}
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
|
@ -44,7 +43,19 @@ def make_config_map(config: dict):
|
||||||
m[k] = guardrail_item
|
m[k] = guardrail_item
|
||||||
return m
|
return m
|
||||||
|
|
||||||
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['system', 'user']}}))
|
|
||||||
|
@patch(
|
||||||
|
"litellm.guardrail_name_config_map",
|
||||||
|
make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection", "prompt_injection_api_2"],
|
||||||
|
"default_on": True,
|
||||||
|
"enabled_roles": ["system", "user"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lakera_prompt_injection_detection():
|
async def test_lakera_prompt_injection_detection():
|
||||||
"""
|
"""
|
||||||
|
@ -78,7 +89,17 @@ async def test_lakera_prompt_injection_detection():
|
||||||
assert "Violated content safety policy" in str(http_exception)
|
assert "Violated content safety policy" in str(http_exception)
|
||||||
|
|
||||||
|
|
||||||
@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
@patch(
|
||||||
|
"litellm.guardrail_name_config_map",
|
||||||
|
make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lakera_safe_prompt():
|
async def test_lakera_safe_prompt():
|
||||||
"""
|
"""
|
||||||
|
@ -152,17 +173,28 @@ async def test_moderations_on_embeddings():
|
||||||
print("got an exception", (str(e)))
|
print("got an exception", (str(e)))
|
||||||
assert "Violated content safety policy" in str(e.message)
|
assert "Violated content safety policy" in str(e.message)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||||
@patch("litellm.guardrail_name_config_map",
|
@patch(
|
||||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True, "enabled_roles": ["user", "system"]}}))
|
"litellm.guardrail_name_config_map",
|
||||||
|
new=make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
"enabled_roles": ["user", "system"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
async def test_messages_for_disabled_role(spy_post):
|
async def test_messages_for_disabled_role(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "assistant", "content": "This should be ignored." },
|
{"role": "assistant", "content": "This should be ignored."},
|
||||||
{"role": "user", "content": "corgi sploot"},
|
{"role": "user", "content": "corgi sploot"},
|
||||||
{"role": "system", "content": "Initial content." },
|
{"role": "system", "content": "Initial content."},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,66 +204,119 @@ async def test_messages_for_disabled_role(spy_post):
|
||||||
{"role": "user", "content": "corgi sploot"},
|
{"role": "user", "content": "corgi sploot"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
await moderation.async_moderation_hook(
|
||||||
|
data=data, user_api_key_dict=None, call_type="completion"
|
||||||
|
)
|
||||||
|
|
||||||
_, kwargs = spy_post.call_args
|
_, kwargs = spy_post.call_args
|
||||||
assert json.loads(kwargs.get('data')) == expected_data
|
assert json.loads(kwargs.get("data")) == expected_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||||
@patch("litellm.guardrail_name_config_map",
|
@patch(
|
||||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
"litellm.guardrail_name_config_map",
|
||||||
|
new=make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
@patch("litellm.add_function_to_prompt", False)
|
@patch("litellm.add_function_to_prompt", False)
|
||||||
async def test_system_message_with_function_input(spy_post):
|
async def test_system_message_with_function_input(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "Initial content." },
|
{"role": "system", "content": "Initial content."},
|
||||||
{"role": "user", "content": "Where are the best sunsets?", "tool_calls": [{"function": {"arguments": "Function args"}}]}
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Where are the best sunsets?",
|
||||||
|
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
expected_data = {
|
expected_data = {
|
||||||
"input": [
|
"input": [
|
||||||
{"role": "system", "content": "Initial content. Function Input: Function args"},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Initial content. Function Input: Function args",
|
||||||
|
},
|
||||||
{"role": "user", "content": "Where are the best sunsets?"},
|
{"role": "user", "content": "Where are the best sunsets?"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
await moderation.async_moderation_hook(
|
||||||
|
data=data, user_api_key_dict=None, call_type="completion"
|
||||||
|
)
|
||||||
|
|
||||||
_, kwargs = spy_post.call_args
|
_, kwargs = spy_post.call_args
|
||||||
assert json.loads(kwargs.get('data')) == expected_data
|
assert json.loads(kwargs.get("data")) == expected_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||||
@patch("litellm.guardrail_name_config_map",
|
@patch(
|
||||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
"litellm.guardrail_name_config_map",
|
||||||
|
new=make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
@patch("litellm.add_function_to_prompt", False)
|
@patch("litellm.add_function_to_prompt", False)
|
||||||
async def test_multi_message_with_function_input(spy_post):
|
async def test_multi_message_with_function_input(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "Initial content.", "tool_calls": [{"function": {"arguments": "Function args"}}]},
|
{
|
||||||
{"role": "user", "content": "Strawberry", "tool_calls": [{"function": {"arguments": "Function args"}}]}
|
"role": "system",
|
||||||
|
"content": "Initial content.",
|
||||||
|
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Strawberry",
|
||||||
|
"tool_calls": [{"function": {"arguments": "Function args"}}],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
expected_data = {
|
expected_data = {
|
||||||
"input": [
|
"input": [
|
||||||
{"role": "system", "content": "Initial content. Function Input: Function args Function args"},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Initial content. Function Input: Function args Function args",
|
||||||
|
},
|
||||||
{"role": "user", "content": "Strawberry"},
|
{"role": "user", "content": "Strawberry"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
await moderation.async_moderation_hook(
|
||||||
|
data=data, user_api_key_dict=None, call_type="completion"
|
||||||
|
)
|
||||||
|
|
||||||
_, kwargs = spy_post.call_args
|
_, kwargs = spy_post.call_args
|
||||||
assert json.loads(kwargs.get('data')) == expected_data
|
assert json.loads(kwargs.get("data")) == expected_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
@patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
|
||||||
@patch("litellm.guardrail_name_config_map",
|
@patch(
|
||||||
new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
|
"litellm.guardrail_name_config_map",
|
||||||
|
new=make_config_map(
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
async def test_message_ordering(spy_post):
|
async def test_message_ordering(spy_post):
|
||||||
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
moderation = _ENTERPRISE_lakeraAI_Moderation()
|
||||||
data = {
|
data = {
|
||||||
|
@ -249,8 +334,57 @@ async def test_message_ordering(spy_post):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
|
await moderation.async_moderation_hook(
|
||||||
|
data=data, user_api_key_dict=None, call_type="completion"
|
||||||
|
)
|
||||||
|
|
||||||
_, kwargs = spy_post.call_args
|
_, kwargs = spy_post.call_args
|
||||||
assert json.loads(kwargs.get('data')) == expected_data
|
assert json.loads(kwargs.get("data")) == expected_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_specific_param_run_pre_call_check_lakera():
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
|
||||||
|
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||||
|
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
|
||||||
|
|
||||||
|
os.environ["LAKERA_API_KEY"] = "7a91a1a6059da*******"
|
||||||
|
|
||||||
|
guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
|
||||||
|
{
|
||||||
|
"prompt_injection": {
|
||||||
|
"callbacks": ["lakera_prompt_injection"],
|
||||||
|
"default_on": True,
|
||||||
|
"callback_args": {
|
||||||
|
"lakera_prompt_injection": {"moderation_check": "pre_call"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
litellm_settings = {"guardrails": guardrails_config}
|
||||||
|
|
||||||
|
assert len(litellm.guardrail_name_config_map) == 0
|
||||||
|
initialize_guardrails(
|
||||||
|
guardrails_config=guardrails_config,
|
||||||
|
premium_user=True,
|
||||||
|
config_file_path="",
|
||||||
|
litellm_settings=litellm_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(litellm.guardrail_name_config_map) == 1
|
||||||
|
|
||||||
|
prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
|
||||||
|
print("litellm callbacks={}".format(litellm.callbacks))
|
||||||
|
for callback in litellm.callbacks:
|
||||||
|
if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
|
||||||
|
prompt_injection_obj = callback
|
||||||
|
else:
|
||||||
|
print("Type of callback={}".format(type(callback)))
|
||||||
|
|
||||||
|
assert prompt_injection_obj is not None
|
||||||
|
|
||||||
|
assert hasattr(prompt_injection_obj, "moderation_check")
|
||||||
|
assert prompt_injection_obj.moderation_check == "pre_call"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
@ -33,6 +33,7 @@ class GuardrailItemSpec(TypedDict, total=False):
|
||||||
default_on: bool
|
default_on: bool
|
||||||
logging_only: Optional[bool]
|
logging_only: Optional[bool]
|
||||||
enabled_roles: Optional[List[Role]]
|
enabled_roles: Optional[List[Role]]
|
||||||
|
callback_args: Dict[str, Dict]
|
||||||
|
|
||||||
|
|
||||||
class GuardrailItem(BaseModel):
|
class GuardrailItem(BaseModel):
|
||||||
|
@ -40,7 +41,9 @@ class GuardrailItem(BaseModel):
|
||||||
default_on: bool
|
default_on: bool
|
||||||
logging_only: Optional[bool]
|
logging_only: Optional[bool]
|
||||||
guardrail_name: str
|
guardrail_name: str
|
||||||
|
callback_args: Dict[str, Dict]
|
||||||
enabled_roles: Optional[List[Role]]
|
enabled_roles: Optional[List[Role]]
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=True)
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -50,6 +53,7 @@ class GuardrailItem(BaseModel):
|
||||||
default_on: bool = False,
|
default_on: bool = False,
|
||||||
logging_only: Optional[bool] = None,
|
logging_only: Optional[bool] = None,
|
||||||
enabled_roles: Optional[List[Role]] = default_roles,
|
enabled_roles: Optional[List[Role]] = default_roles,
|
||||||
|
callback_args: Dict[str, Dict] = {},
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
@ -57,4 +61,5 @@ class GuardrailItem(BaseModel):
|
||||||
logging_only=logging_only,
|
logging_only=logging_only,
|
||||||
guardrail_name=guardrail_name,
|
guardrail_name=guardrail_name,
|
||||||
enabled_roles=enabled_roles,
|
enabled_roles=enabled_roles,
|
||||||
|
callback_args=callback_args,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue