(Feat) - allow setting default_on guardrails (#7973)

* test_default_on_guardrail

* update debug on custom guardrail

* refactor guardrails init

* guardrail registry

* allow switching guardrails default_on

* fix circle import issue

* fix bedrock applying guardrails where content is a list

* fix unused import

* docs default on guardrail

* docs fix per api key
Ishaan Jaff 2025-01-24 10:14:05 -08:00 committed by GitHub
parent bedd48e83e
commit ed283bc5b4
10 changed files with 292 additions and 325 deletions


@@ -1,4 +1,5 @@
import importlib
import os
from typing import Dict, List, Optional
import litellm
@@ -9,14 +10,14 @@ from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_pr
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailEventHooks,
GuardrailItem,
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
SupportedGuardrailIntegrations,
)
from .guardrail_registry import guardrail_registry
all_guardrails: List[GuardrailItem] = []
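
The new guardrail_registry module itself isn't part of this diff; judging from how it's used below (guardrail_registry.get(guardrail_type) returns an initializer that is called as initializer(litellm_params, guardrail)), a registry along these lines would fit. This is an illustrative sketch that reuses the constructor arguments from the removed Bedrock branch further down; the helper name _init_bedrock_guardrail is an assumption, not the actual module contents.

```python
# Hypothetical sketch of guardrail_registry, assuming it maps each supported
# integration string to an initializer callable; not the actual module.
from typing import Callable, Dict

import litellm
from litellm.types.guardrails import SupportedGuardrailIntegrations


def _init_bedrock_guardrail(litellm_params: Dict, guardrail: Dict) -> None:
    # Reuses the constructor call from the removed Bedrock branch below.
    from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
        BedrockGuardrail,
    )

    _bedrock_callback = BedrockGuardrail(
        guardrail_name=guardrail["guardrail_name"],
        event_hook=litellm_params["mode"],
        guardrailIdentifier=litellm_params["guardrailIdentifier"],
        guardrailVersion=litellm_params["guardrailVersion"],
    )
    litellm.callbacks.append(_bedrock_callback)  # type: ignore


# init_guardrails_v2 looks the initializer up by litellm_params["guardrail"].
guardrail_registry: Dict[str, Callable[..., None]] = {
    SupportedGuardrailIntegrations.BEDROCK.value: _init_bedrock_guardrail,
    # ... one entry each for aporia, lakera, aim, presidio, hide-secrets, guardrails_ai
}
```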
@@ -83,23 +84,18 @@ Map guardrail_name: <pre_call>, <post_call>, during_call
"""
def init_guardrails_v2( # noqa: PLR0915
def init_guardrails_v2(
all_guardrails: List[Dict],
config_file_path: Optional[str] = None,
):
# Convert the loaded data to the TypedDict structure
guardrail_list = []
# Parse each guardrail and replace environment variables
for guardrail in all_guardrails:
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
_litellm_params_kwargs = {
k: litellm_params_data[k] if k in litellm_params_data else None
for k in LitellmParams.__annotations__.keys()
k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
}
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
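
A small note on the kwargs refactor just above: both the old `litellm_params_data[k] if k in litellm_params_data else None` and the new `litellm_params_data.get(k)` default missing config keys to None, so an unset `default_on` simply comes through as None. A tiny illustration (the config dict here is made up):

```python
from litellm.types.guardrails import LitellmParams

# Illustrative config fragment; any LitellmParams key the user omits becomes None.
litellm_params_data = {"guardrail": "bedrock", "mode": "pre_call", "default_on": True}

_litellm_params_kwargs = {
    k: litellm_params_data.get(k)  # None when the key is absent
    for k in LitellmParams.__annotations__.keys()
}
# _litellm_params_kwargs["default_on"] -> True; unset keys such as "api_key" -> None
```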
@@ -113,157 +109,41 @@ def init_guardrails_v2( # noqa: PLR0915
)
litellm_params["category_thresholds"] = lakera_category_thresholds
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == SupportedGuardrailIntegrations.APORIA.value:
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
AporiaGuardrail,
)
_aporia_callback = AporiaGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
elif (
litellm_params["guardrail"] == SupportedGuardrailIntegrations.BEDROCK.value
if litellm_params["api_key"] and litellm_params["api_key"].startswith(
"os.environ/"
):
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
BedrockGuardrail,
)
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
_bedrock_callback = BedrockGuardrail(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
guardrailIdentifier=litellm_params["guardrailIdentifier"],
guardrailVersion=litellm_params["guardrailVersion"],
)
litellm.callbacks.append(_bedrock_callback) # type: ignore
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.LAKERA.value:
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,
)
_lakera_callback = lakeraAI_Moderation(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
category_thresholds=litellm_params.get("category_thresholds"),
)
litellm.callbacks.append(_lakera_callback) # type: ignore
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.AIM.value:
from litellm.proxy.guardrails.guardrail_hooks.aim import (
AimGuardrail,
)
_aim_callback = AimGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aim_callback) # type: ignore
elif (
litellm_params["guardrail"] == SupportedGuardrailIntegrations.PRESIDIO.value
if litellm_params["api_base"] and litellm_params["api_base"].startswith(
"os.environ/"
):
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
_presidio_callback = _OPTIONAL_PresidioPIIMasking(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
output_parse_pii=litellm_params["output_parse_pii"],
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
)
guardrail_type = litellm_params["guardrail"]
if litellm_params["output_parse_pii"] is True:
_success_callback = _OPTIONAL_PresidioPIIMasking(
output_parse_pii=True,
guardrail_name=guardrail["guardrail_name"],
event_hook=GuardrailEventHooks.post_call.value,
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
)
initializer = guardrail_registry.get(guardrail_type)
litellm.callbacks.append(_success_callback) # type: ignore
litellm.callbacks.append(_presidio_callback) # type: ignore
elif (
litellm_params["guardrail"]
== SupportedGuardrailIntegrations.HIDE_SECRETS.value
):
from enterprise.enterprise_hooks.secret_detection import (
_ENTERPRISE_SecretDetection,
)
_secret_detection_object = _ENTERPRISE_SecretDetection(
detect_secrets_config=litellm_params.get("detect_secrets_config"),
event_hook=litellm_params["mode"],
guardrail_name=guardrail["guardrail_name"],
)
litellm.callbacks.append(_secret_detection_object) # type: ignore
elif (
litellm_params["guardrail"]
== SupportedGuardrailIntegrations.GURDRAILS_AI.value
):
from litellm.proxy.guardrails.guardrail_hooks.guardrails_ai import (
GuardrailsAI,
)
_guard_name = litellm_params.get("guard_name")
if _guard_name is None:
raise Exception(
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
)
_guardrails_ai_callback = GuardrailsAI(
api_base=litellm_params.get("api_base"),
guard_name=_guard_name,
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
)
litellm.callbacks.append(_guardrails_ai_callback) # type: ignore
elif (
isinstance(litellm_params["guardrail"], str)
and "." in litellm_params["guardrail"]
):
if config_file_path is None:
if initializer:
initializer(litellm_params, guardrail)
elif isinstance(guardrail_type, str) and "." in guardrail_type:
if not config_file_path:
raise Exception(
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
)
import os
# Custom guardrail
_guardrail = litellm_params["guardrail"]
_file_name, _class_name = _guardrail.split(".")
_file_name, _class_name = guardrail_type.split(".")
verbose_proxy_logger.debug(
"Initializing custom guardrail: %s, file_name: %s, class_name: %s",
_guardrail,
guardrail_type,
_file_name,
_class_name,
)
directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, _file_name)
module_file_path += ".py"
module_file_path = os.path.join(directory, _file_name) + ".py"
spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore
if spec is None:
if not spec:
raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
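
This hunk cuts off right after the spec check; the lines between here and the next hunk (where _guardrail_class is instantiated) aren't shown. For orientation, this is the standard importlib flow such code typically follows — a hedged sketch, not the literal elided lines, and the file/class names are made up:

```python
# Hedged sketch of the usual importlib steps between the two hunks; the paths
# and names below are illustrative, not taken from the diff.
import importlib.util

module_file_path = "/path/to/proxy_config_dir/custom_guardrail.py"  # illustrative
_class_name = "myCustomGuardrail"  # illustrative

spec = importlib.util.spec_from_file_location(_class_name, module_file_path)
if not spec:
    raise ImportError(f"Could not find a module specification for {module_file_path}")

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)  # type: ignore[union-attr]

# "custom_guardrail.myCustomGuardrail" -> load the class named after the dot
_guardrail_class = getattr(module, _class_name)
```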
@@ -275,10 +155,11 @@ def init_guardrails_v2( # noqa: PLR0915
_guardrail_callback = _guardrail_class(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_guardrail_callback) # type: ignore
else:
raise ValueError(f"Unsupported guardrail: {litellm_params['guardrail']}")
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
parsed_guardrail = Guardrail(
guardrail_name=guardrail["guardrail_name"],
@@ -286,6 +167,5 @@ def init_guardrails_v2( # noqa: PLR0915
)
guardrail_list.append(parsed_guardrail)
guardrail["guardrail_name"]
# pretty print guardrail_list in green
print(f"\nGuardrail List:{guardrail_list}\n") # noqa