diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py index 921859997..029f9dd9f 100644 --- a/enterprise/enterprise_hooks/lakera_ai.py +++ b/enterprise/enterprise_hooks/lakera_ai.py @@ -42,7 +42,7 @@ class LakeraCategories(TypedDict, total=False): prompt_injection: float -class _ENTERPRISE_lakeraAI_Moderation(CustomLogger): +class lakeraAI_Moderation(CustomLogger): def __init__( self, moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel", diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 9516f72fb..44730825d 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -101,9 +101,7 @@ def initialize_callbacks_on_proxy( openai_moderations_object = _ENTERPRISE_OpenAI_Moderation() imported_list.append(openai_moderations_object) elif isinstance(callback, str) and callback == "lakera_prompt_injection": - from enterprise.enterprise_hooks.lakera_ai import ( - _ENTERPRISE_lakeraAI_Moderation, - ) + from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation if premium_user != True: raise Exception( @@ -114,9 +112,7 @@ def initialize_callbacks_on_proxy( init_params = {} if "lakera_prompt_injection" in callback_specific_params: init_params = callback_specific_params["lakera_prompt_injection"] - lakera_moderations_object = _ENTERPRISE_lakeraAI_Moderation( - **init_params - ) + lakera_moderations_object = lakeraAI_Moderation(**init_params) imported_list.append(lakera_moderations_object) elif isinstance(callback, str) and callback == "aporia_prompt_injection": from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import ( diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 93b8e3d5c..787a58cd0 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -126,10 +126,10 @@ def init_guardrails_v2(all_guardrails: dict): litellm.callbacks.append(_aporia_callback) # type: ignore elif litellm_params["guardrail"] == "lakera": from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( - _ENTERPRISE_lakeraAI_Moderation, + lakeraAI_Moderation, ) - _lakera_callback = _ENTERPRISE_lakeraAI_Moderation() + _lakera_callback = lakeraAI_Moderation() litellm.callbacks.append(_lakera_callback) # type: ignore parsed_guardrail = Guardrail( diff --git a/litellm/tests/test_lakera_ai_prompt_injection.py b/litellm/tests/test_lakera_ai_prompt_injection.py index 01829468c..d010a52ae 100644 --- a/litellm/tests/test_lakera_ai_prompt_injection.py +++ b/litellm/tests/test_lakera_ai_prompt_injection.py @@ -27,9 +27,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import ( - _ENTERPRISE_lakeraAI_Moderation, -) +from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation from litellm.proxy.proxy_server import embeddings from litellm.proxy.utils import ProxyLogging, hash_token @@ -62,7 +60,7 @@ async def test_lakera_prompt_injection_detection(): Tests to see OpenAI Moderation raises an error for a flagged response """ - lakera_ai = _ENTERPRISE_lakeraAI_Moderation() + lakera_ai = lakeraAI_Moderation() _api_key = "sk-12345" _api_key = hash_token("sk-12345") user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) @@ -106,7 +104,7 @@ async def test_lakera_safe_prompt(): Nothing should get raised here """ - lakera_ai = _ENTERPRISE_lakeraAI_Moderation() + lakera_ai = lakeraAI_Moderation() _api_key = "sk-12345" _api_key = hash_token("sk-12345") user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) @@ -144,7 +142,7 @@ async def test_moderations_on_embeddings(): setattr(litellm.proxy.proxy_server, "llm_router", temp_router) api_route = APIRoute(path="/embeddings", endpoint=embeddings) - litellm.callbacks = [_ENTERPRISE_lakeraAI_Moderation()] + litellm.callbacks = [lakeraAI_Moderation()] request = Request( { "type": "http", @@ -189,7 +187,7 @@ async def test_moderations_on_embeddings(): ), ) async def test_messages_for_disabled_role(spy_post): - moderation = _ENTERPRISE_lakeraAI_Moderation() + moderation = lakeraAI_Moderation() data = { "messages": [ {"role": "assistant", "content": "This should be ignored."}, @@ -227,7 +225,7 @@ async def test_messages_for_disabled_role(spy_post): ) @patch("litellm.add_function_to_prompt", False) async def test_system_message_with_function_input(spy_post): - moderation = _ENTERPRISE_lakeraAI_Moderation() + moderation = lakeraAI_Moderation() data = { "messages": [ {"role": "system", "content": "Initial content."}, @@ -271,7 +269,7 @@ async def test_system_message_with_function_input(spy_post): ) @patch("litellm.add_function_to_prompt", False) async def test_multi_message_with_function_input(spy_post): - moderation = _ENTERPRISE_lakeraAI_Moderation() + moderation = lakeraAI_Moderation() data = { "messages": [ { @@ -318,7 +316,7 @@ async def test_multi_message_with_function_input(spy_post): ), ) async def test_message_ordering(spy_post): - moderation = _ENTERPRISE_lakeraAI_Moderation() + moderation = lakeraAI_Moderation() data = { "messages": [ {"role": "assistant", "content": "Assistant message."}, @@ -347,7 +345,7 @@ async def test_callback_specific_param_run_pre_call_check_lakera(): from typing import Dict, List, Optional, Union import litellm - from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation + from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation from litellm.proxy.guardrails.init_guardrails import initialize_guardrails from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec @@ -374,10 +372,10 @@ async def test_callback_specific_param_run_pre_call_check_lakera(): assert len(litellm.guardrail_name_config_map) == 1 - prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None + prompt_injection_obj: Optional[lakeraAI_Moderation] = None print("litellm callbacks={}".format(litellm.callbacks)) for callback in litellm.callbacks: - if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation): + if isinstance(callback, lakeraAI_Moderation): prompt_injection_obj = callback else: print("Type of callback={}".format(type(callback))) @@ -393,7 +391,7 @@ async def test_callback_specific_thresholds(): from typing import Dict, List, Optional, Union import litellm - from enterprise.enterprise_hooks.lakera_ai import _ENTERPRISE_lakeraAI_Moderation + from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation from litellm.proxy.guardrails.init_guardrails import initialize_guardrails from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec @@ -426,10 +424,10 @@ async def test_callback_specific_thresholds(): assert len(litellm.guardrail_name_config_map) == 1 - prompt_injection_obj: Optional[_ENTERPRISE_lakeraAI_Moderation] = None + prompt_injection_obj: Optional[lakeraAI_Moderation] = None print("litellm callbacks={}".format(litellm.callbacks)) for callback in litellm.callbacks: - if isinstance(callback, _ENTERPRISE_lakeraAI_Moderation): + if isinstance(callback, lakeraAI_Moderation): prompt_injection_obj = callback else: print("Type of callback={}".format(type(callback))) diff --git a/litellm/tests/test_proxy_setting_guardrails.py b/litellm/tests/test_proxy_setting_guardrails.py index 048951da0..e5baa1fa8 100644 --- a/litellm/tests/test_proxy_setting_guardrails.py +++ b/litellm/tests/test_proxy_setting_guardrails.py @@ -48,7 +48,7 @@ def test_active_callbacks(client): _active_callbacks = json_response["litellm.callbacks"] expected_callback_names = [ - "_ENTERPRISE_lakeraAI_Moderation", + "lakeraAI_Moderation", "_OPTIONAL_PromptInjectionDetectio", "_ENTERPRISE_SecretDetection", ]