diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index a3ac2ea863..047d1b6d37 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional from litellm._logging import verbose_logger from litellm.integrations.custom_logger import CustomLogger @@ -7,9 +7,14 @@ from litellm.types.guardrails import GuardrailEventHooks class CustomGuardrail(CustomLogger): - def __init__(self, guardrail_name: str, event_hook: GuardrailEventHooks, **kwargs): + def __init__( + self, + guardrail_name: Optional[str] = None, + event_hook: Optional[GuardrailEventHooks] = None, + **kwargs + ): self.guardrail_name = guardrail_name - self.event_hook: GuardrailEventHooks = event_hook + self.event_hook: Optional[GuardrailEventHooks] = event_hook super().__init__(**kwargs) def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 44730825df..243ae18135 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -101,13 +101,9 @@ def initialize_callbacks_on_proxy( openai_moderations_object = _ENTERPRISE_OpenAI_Moderation() imported_list.append(openai_moderations_object) elif isinstance(callback, str) and callback == "lakera_prompt_injection": - from enterprise.enterprise_hooks.lakera_ai import lakeraAI_Moderation - - if premium_user != True: - raise Exception( - "Trying to use LakeraAI Prompt Injection" - + CommonProxyErrors.not_premium_user.value - ) + from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import ( + lakeraAI_Moderation, + ) init_params = {} if "lakera_prompt_injection" in callback_specific_params: @@ -119,12 +115,6 @@ def initialize_callbacks_on_proxy( AporiaGuardrail, ) - if premium_user is not True: - raise Exception( - "Trying to use Aporia AI Guardrail" - + CommonProxyErrors.not_premium_user.value - ) - aporia_guardrail_object = AporiaGuardrail() imported_list.append(aporia_guardrail_object) elif isinstance(callback, str) and callback == "google_text_moderation": @@ -305,7 +295,11 @@ def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]: return None -def add_guardrail_to_applied_guardrails_header(request_data: Dict, guardrail_name: str): +def add_guardrail_to_applied_guardrails_header( + request_data: Dict, guardrail_name: Optional[str] +): + if guardrail_name is None: + return _metadata = request_data.get("metadata", None) or {} if "applied_guardrails" in _metadata: _metadata["applied_guardrails"].append(guardrail_name) diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index c90802e542..8ee856da88 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -301,8 +301,9 @@ class lakeraAI_Moderation(CustomGuardrail): user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): - if self.moderation_check == "pre_call": - return + if self.event_hook is None: + if self.moderation_check == "pre_call": + return from litellm.types.guardrails import GuardrailEventHooks diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index dc27868d84..95267e6bb7 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -129,11 +129,17 @@ def init_guardrails_v2(all_guardrails: dict): lakeraAI_Moderation, ) - _lakera_callback = lakeraAI_Moderation() + _lakera_callback = lakeraAI_Moderation( + api_base=litellm_params["api_base"], + api_key=litellm_params["api_key"], + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + ) litellm.callbacks.append(_lakera_callback) # type: ignore parsed_guardrail = Guardrail( - guardrail_name=guardrail["guardrail_name"], litellm_params=litellm_params + guardrail_name=guardrail["guardrail_name"], + litellm_params=litellm_params, ) guardrail_list.append(parsed_guardrail) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 2f0690e173..a52f97852d 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -15,7 +15,7 @@ guardrails: - guardrail_name: "lakera-pre-guard" litellm_params: guardrail: lakera # supported values: "aporia", "bedrock", "lakera" - mode: "pre_call" + mode: "during_call" api_key: os.environ/LAKERA_API_KEY api_base: os.environ/LAKERA_API_BASE \ No newline at end of file