diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index f0e2a9e2ec..4180bce2d3 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -1,3 +1,4 @@ +import importlib import traceback from typing import Dict, List, Literal @@ -161,6 +162,25 @@ def init_guardrails_v2(all_guardrails: dict): category_thresholds=litellm_params.get("category_thresholds"), ) litellm.callbacks.append(_lakera_callback) # type: ignore + elif ( + isinstance(litellm_params["guardrail"], str) + and "." in litellm_params["guardrail"] + ): + # Custom guardrail + _guardrail = litellm_params["guardrail"] + _file_name, _class_name = _guardrail.split(".") + verbose_proxy_logger.debug( + "Initializing custom guardrail: %s, file_name: %s, class_name: %s", + _guardrail, + _file_name, + _class_name, + ) + _guardrail_class = getattr(importlib.import_module(_file_name), _class_name) + _guardrail_callback = _guardrail_class( + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + ) + litellm.callbacks.append(_guardrail_callback) # type: ignore parsed_guardrail = Guardrail( guardrail_name=guardrail["guardrail_name"], @@ -169,6 +189,5 @@ def init_guardrails_v2(all_guardrails: dict): guardrail_list.append(parsed_guardrail) guardrail_name = guardrail["guardrail_name"] - # pretty print guardrail_list in green print(f"\nGuardrail List:{guardrail_list}\n") # noqa