diff --git a/docs/my-website/docs/proxy/guardrails.md b/docs/my-website/docs/proxy/guardrails.md
index 053fa8cab..f43b264e9 100644
--- a/docs/my-website/docs/proxy/guardrails.md
+++ b/docs/my-website/docs/proxy/guardrails.md
@@ -290,6 +290,7 @@ litellm_settings:
   - Full List: presidio, lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation
 - `default_on`: bool, will run on all llm requests when true
 - `logging_only`: Optional[bool], if true, run guardrail only on logged output, not on the actual LLM API call. Currently only supported for presidio pii masking. Requires `default_on` to be True as well.
+- `callback_args`: Optional[Dict[str, Dict]], if set, pass these init args to the named guardrail callback when it is initialized
 
 Example:
 
@@ -299,6 +300,7 @@ litellm_settings:
   - prompt_injection:  # your custom name for guardrail
       callbacks: [lakera_prompt_injection, hide_secrets, llmguard_moderations, llamaguard_moderations, google_text_moderation] # litellm callbacks to use
       default_on: true # will run on all llm requests when true
+      callback_args: {"lakera_prompt_injection": {"moderation_check": "pre_call"}}
   - hide_secrets:
       callbacks: [hide_secrets]
       default_on: true
diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py
index 75e346cdb..14ff595f9 100644
--- a/enterprise/enterprise_hooks/lakera_ai.py
+++ b/enterprise/enterprise_hooks/lakera_ai.py
@@ -10,7 +10,7 @@ import sys, os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from typing import Literal, List, Dict
+from typing import Literal, List, Dict, Optional, Union
 import litellm, sys
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -38,14 +38,38 @@ INPUT_POSITIONING_MAP = {
 
 
 class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
-    def __init__(self):
+    def __init__(
+        self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel"
+    ):
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
         self.lakera_api_key = os.environ["LAKERA_API_KEY"]
+        self.moderation_check = moderation_check
         pass
 
     #### CALL HOOKS - proxy only ####
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: litellm.DualCache,
+        data: Dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+        ],
+    ) -> Optional[Union[Exception, str, Dict]]:
+        if self.moderation_check == "in_parallel":
+            return None
+
+        return await super().async_pre_call_hook(
+            user_api_key_dict, cache, data, call_type
+        )
 
     async def async_moderation_hook(  ### 👈 KEY CHANGE ###
         self,
@@ -53,6 +77,8 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
         user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
+        if self.moderation_check == "pre_call":
+            return
 
         if (
             await should_proceed_based_on_metadata(
diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py
index 489f9b3a6..bd52efb19 100644
--- a/litellm/proxy/common_utils/init_callbacks.py
+++ b/litellm/proxy/common_utils/init_callbacks.py
@@ -110,7 +110,12 @@ def initialize_callbacks_on_proxy(
                     + CommonProxyErrors.not_premium_user.value
                 )
 
-            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation()
+            init_params = {}
+            if "lakera_prompt_injection" in callback_specific_params:
+                init_params = callback_specific_params["lakera_prompt_injection"]
+            lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation(
+                **init_params
+            )
             imported_list.append(lakera_moderations_object)
         elif isinstance(callback, str) and callback == "aporio_prompt_injection":
             from enterprise.enterprise_hooks.aporio_ai import _ENTERPRISE_Aporio
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index 0afc17487..e98beb817 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -38,6 +38,8 @@ def initialize_guardrails(
             verbose_proxy_logger.debug(guardrail.guardrail_name)
             verbose_proxy_logger.debug(guardrail.default_on)
 
+            callback_specific_params.update(guardrail.callback_args)
+
             if guardrail.default_on is True:
                 # add these to litellm callbacks if they don't exist
                 for callback in guardrail.callbacks:
@@ -46,7 +48,7 @@ def initialize_guardrails(
 
                     if guardrail.logging_only is True:
                         if callback == "presidio":
-                            callback_specific_params["logging_only"] = True
+                            callback_specific_params["logging_only"] = True  # type: ignore
 
         default_on_callbacks_list = list(default_on_callbacks)
         if len(default_on_callbacks_list) > 0:
diff --git a/litellm/tests/test_lakera_ai_prompt_injection.py b/litellm/tests/test_lakera_ai_prompt_injection.py
index c3839d4e0..ec1750ab2 100644
--- a/litellm/tests/test_lakera_ai_prompt_injection.py
+++ b/litellm/tests/test_lakera_ai_prompt_injection.py
@@ -1,15 +1,15 @@
 # What is this?
 ## This tests the Lakera AI integration
 
+import json
 import os
 import sys
-import json
 
 from dotenv import load_dotenv
 from fastapi import HTTPException, Request, Response
 from fastapi.routing import APIRoute
 from starlette.datastructures import URL
-from fastapi import HTTPException
+
 from litellm.types.guardrails import GuardrailItem
 
 load_dotenv()
@@ -19,6 +19,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import logging
+from unittest.mock import patch
 
 import pytest
 
@@ -31,12 +32,10 @@ from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
 )
 from litellm.proxy.proxy_server import embeddings
 from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy.utils import hash_token
-from unittest.mock import patch
-
 
 verbose_proxy_logger.setLevel(logging.DEBUG)
 
+
 def make_config_map(config: dict):
     m = {}
     for k, v in config.items():
@@ -44,7 +43,19 @@ def make_config_map(config: dict):
         m[k] = guardrail_item
     return m
 
-@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['system', 'user']}}))
+
+@patch(
+    "litellm.guardrail_name_config_map",
+    make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection", "prompt_injection_api_2"],
+                "default_on": True,
+                "enabled_roles": ["system", "user"],
+            }
+        }
+    ),
+)
 @pytest.mark.asyncio
 async def test_lakera_prompt_injection_detection():
     """
@@ -78,7 +89,17 @@ async def test_lakera_prompt_injection_detection():
 
         assert "Violated content safety policy" in str(http_exception)
 
-@patch('litellm.guardrail_name_config_map', make_config_map({'prompt_injection': {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @pytest.mark.asyncio
 async def test_lakera_safe_prompt():
     """
@@ -152,17 +173,28 @@ async def test_moderations_on_embeddings():
         print("got an exception", (str(e)))
         assert "Violated content safety policy" in str(e.message)
 
+
 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True, "enabled_roles": ["user", "system"]}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+                "enabled_roles": ["user", "system"],
+            }
+        }
+    ),
+)
 async def test_messages_for_disabled_role(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "assistant", "content": "This should be ignored." },
+            {"role": "assistant", "content": "This should be ignored."},
             {"role": "user", "content": "corgi sploot"},
-            {"role": "system", "content": "Initial content." },
+            {"role": "system", "content": "Initial content."},
         ]
     }
 
@@ -172,66 +204,119 @@ async def test_messages_for_disabled_role(spy_post):
             {"role": "user", "content": "corgi sploot"},
         ]
     }
-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
-
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )
+
     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data
+
 
 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @patch("litellm.add_function_to_prompt", False)
 async def test_system_message_with_function_input(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "system", "content": "Initial content." },
-            {"role": "user", "content": "Where are the best sunsets?", "tool_calls": [{"function": {"arguments": "Function args"}}]}
+            {"role": "system", "content": "Initial content."},
+            {
+                "role": "user",
+                "content": "Where are the best sunsets?",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
         ]
     }
 
     expected_data = {
         "input": [
-            {"role": "system", "content": "Initial content. Function Input: Function args"},
+            {
+                "role": "system",
+                "content": "Initial content. Function Input: Function args",
+            },
             {"role": "user", "content": "Where are the best sunsets?"},
         ]
     }
-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )
 
     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data
+
 
 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 @patch("litellm.add_function_to_prompt", False)
 async def test_multi_message_with_function_input(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
         "messages": [
-            {"role": "system", "content": "Initial content.", "tool_calls": [{"function": {"arguments": "Function args"}}]},
-            {"role": "user", "content": "Strawberry", "tool_calls": [{"function": {"arguments": "Function args"}}]}
+            {
+                "role": "system",
+                "content": "Initial content.",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
+            {
+                "role": "user",
+                "content": "Strawberry",
+                "tool_calls": [{"function": {"arguments": "Function args"}}],
+            },
         ]
     }
 
     expected_data = {
         "input": [
-            {"role": "system", "content": "Initial content. Function Input: Function args Function args"},
+            {
+                "role": "system",
+                "content": "Initial content. Function Input: Function args Function args",
+            },
             {"role": "user", "content": "Strawberry"},
         ]
     }
-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )
 
     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data
 
 
 @pytest.mark.asyncio
 @patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post")
-@patch("litellm.guardrail_name_config_map",
-       new=make_config_map({"prompt_injection": {'callbacks': ['lakera_prompt_injection'], 'default_on': True}}))
+@patch(
+    "litellm.guardrail_name_config_map",
+    new=make_config_map(
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+            }
+        }
+    ),
+)
 async def test_message_ordering(spy_post):
     moderation = _ENTERPRISE_lakeraAI_Moderation()
     data = {
@@ -249,8 +334,57 @@
         ]
     }
 
-    await moderation.async_moderation_hook(data=data, user_api_key_dict=None, call_type="completion")
+    await moderation.async_moderation_hook(
+        data=data, user_api_key_dict=None, call_type="completion"
+    )
 
     _, kwargs = spy_post.call_args
-    assert json.loads(kwargs.get('data')) == expected_data
+    assert json.loads(kwargs.get("data")) == expected_data
+
+
+@pytest.mark.asyncio
+async def test_callback_specific_param_run_pre_call_check_lakera():
+    from typing import Dict, List, Optional, Union
+
+    import litellm
+    from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation
+    from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+    from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+    os.environ["LAKERA_API_KEY"] = "7a91a1a6059da*******"
+
+    guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
+        {
+            "prompt_injection": {
+                "callbacks": ["lakera_prompt_injection"],
+                "default_on": True,
+                "callback_args": {
+                    "lakera_prompt_injection": {"moderation_check": "pre_call"}
+                },
+            }
+        }
+    ]
+    litellm_settings = {"guardrails": guardrails_config}
+
+    assert len(litellm.guardrail_name_config_map) == 0
+    initialize_guardrails(
+        guardrails_config=guardrails_config,
+        premium_user=True,
+        config_file_path="",
+        litellm_settings=litellm_settings,
+    )
+
+    assert len(litellm.guardrail_name_config_map) == 1
+
+    prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None
+    print("litellm callbacks={}".format(litellm.callbacks))
+    for callback in litellm.callbacks:
+        if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation):
+            prompt_injection_obj = callback
+        else:
+            print("Type of callback={}".format(type(callback)))
+
+    assert prompt_injection_obj is not None
+
+    assert hasattr(prompt_injection_obj, "moderation_check")
+    assert prompt_injection_obj.moderation_check == "pre_call"
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 27be12615..0296d8de4 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Required, TypedDict
@@ -33,6 +33,7 @@ class GuardrailItemSpec(TypedDict, total=False):
     default_on: bool
     logging_only: Optional[bool]
     enabled_roles: Optional[List[Role]]
+    callback_args: Dict[str, Dict]
 
 
 class GuardrailItem(BaseModel):
@@ -40,7 +41,9 @@ class GuardrailItem(BaseModel):
     default_on: bool
     logging_only: Optional[bool]
     guardrail_name: str
+    callback_args: Dict[str, Dict]
     enabled_roles: Optional[List[Role]]
+
     model_config = ConfigDict(use_enum_values=True)
 
     def __init__(
@@ -50,6 +53,7 @@ class GuardrailItem(BaseModel):
         default_on: bool = False,
         logging_only: Optional[bool] = None,
         enabled_roles: Optional[List[Role]] = default_roles,
+        callback_args: Dict[str, Dict] = {},
     ):
         super().__init__(
             callbacks=callbacks,
@@ -57,4 +61,5 @@ class GuardrailItem(BaseModel):
             logging_only=logging_only,
             guardrail_name=guardrail_name,
             enabled_roles=enabled_roles,
+            callback_args=callback_args,
         )
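Taken together, the patch lets a proxy config attach per-callback init args via `callback_args`: `initialize_guardrails` merges each guardrail's `callback_args` into `callback_specific_params`, and `initialize_callbacks_on_proxy` forwards the `lakera_prompt_injection` entry into the Lakera hook's constructor, switching its check from an in-parallel moderation hook to a blocking pre-call hook. Below is a minimal sketch of that flow using only the types added in this diff; the guardrail name and argument values are illustrative, not required by the code.

```python
# Illustrative sketch of the callback_args flow (not part of the patch itself).
from litellm.types.guardrails import GuardrailItem

# A guardrail entry as it would be parsed from litellm_settings.guardrails
guardrail = GuardrailItem(
    guardrail_name="prompt_injection",  # illustrative name for the guardrail
    callbacks=["lakera_prompt_injection"],
    default_on=True,
    callback_args={"lakera_prompt_injection": {"moderation_check": "pre_call"}},
)

# initialize_guardrails() merges each guardrail's callback_args into
# callback_specific_params ...
callback_specific_params: dict = {}
callback_specific_params.update(guardrail.callback_args)

# ... and initialize_callbacks_on_proxy() passes the lakera_prompt_injection
# entry as keyword args when constructing the hook, so the Lakera check runs
# before the LLM call instead of alongside it.
init_params = callback_specific_params.get("lakera_prompt_injection", {})
assert init_params == {"moderation_check": "pre_call"}
# _ENTERPRISE_lakeraAI_Moderation(**init_params)  # requires LAKERA_API_KEY to be set
```

With no `callback_args` in the config, `init_params` stays empty and the hook keeps its default `moderation_check="in_parallel"` behavior, so existing deployments are unaffected.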