mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(Feat) - allow setting default_on
guardrails (#7973)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* test_default_on_guardrail * update debug on custom guardrail * refactor guardrails init * guardrail registry * allow switching guardrails default_on * fix circle import issue * fix bedrock applying guardrails where content is a list * fix unused import * docs default on guardrail * docs fix per api key
This commit is contained in:
parent
bedd48e83e
commit
ed283bc5b4
10 changed files with 292 additions and 325 deletions
|
@ -1,4 +1,5 @@
|
|||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
|
@ -9,14 +10,14 @@ from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_pr
|
|||
# v2 implementation
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
GuardrailEventHooks,
|
||||
GuardrailItem,
|
||||
GuardrailItemSpec,
|
||||
LakeraCategoryThresholds,
|
||||
LitellmParams,
|
||||
SupportedGuardrailIntegrations,
|
||||
)
|
||||
|
||||
from .guardrail_registry import guardrail_registry
|
||||
|
||||
all_guardrails: List[GuardrailItem] = []
|
||||
|
||||
|
||||
|
@ -83,23 +84,18 @@ Map guardrail_name: <pre_call>, <post_call>, during_call
|
|||
"""
|
||||
|
||||
|
||||
def init_guardrails_v2( # noqa: PLR0915
|
||||
def init_guardrails_v2(
|
||||
all_guardrails: List[Dict],
|
||||
config_file_path: Optional[str] = None,
|
||||
):
|
||||
# Convert the loaded data to the TypedDict structure
|
||||
guardrail_list = []
|
||||
|
||||
# Parse each guardrail and replace environment variables
|
||||
for guardrail in all_guardrails:
|
||||
|
||||
# Init litellm params for guardrail
|
||||
litellm_params_data = guardrail["litellm_params"]
|
||||
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
|
||||
|
||||
_litellm_params_kwargs = {
|
||||
k: litellm_params_data[k] if k in litellm_params_data else None
|
||||
for k in LitellmParams.__annotations__.keys()
|
||||
k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
|
||||
}
|
||||
|
||||
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
|
||||
|
@ -113,157 +109,41 @@ def init_guardrails_v2( # noqa: PLR0915
|
|||
)
|
||||
litellm_params["category_thresholds"] = lakera_category_thresholds
|
||||
|
||||
if litellm_params["api_key"]:
|
||||
if litellm_params["api_key"].startswith("os.environ/"):
|
||||
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
|
||||
|
||||
if litellm_params["api_base"]:
|
||||
if litellm_params["api_base"].startswith("os.environ/"):
|
||||
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
|
||||
|
||||
# Init guardrail CustomLoggerClass
|
||||
if litellm_params["guardrail"] == SupportedGuardrailIntegrations.APORIA.value:
|
||||
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
||||
AporiaGuardrail,
|
||||
)
|
||||
|
||||
_aporia_callback = AporiaGuardrail(
|
||||
api_base=litellm_params["api_base"],
|
||||
api_key=litellm_params["api_key"],
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
)
|
||||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||
elif (
|
||||
litellm_params["guardrail"] == SupportedGuardrailIntegrations.BEDROCK.value
|
||||
if litellm_params["api_key"] and litellm_params["api_key"].startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
|
||||
BedrockGuardrail,
|
||||
)
|
||||
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
|
||||
|
||||
_bedrock_callback = BedrockGuardrail(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
guardrailIdentifier=litellm_params["guardrailIdentifier"],
|
||||
guardrailVersion=litellm_params["guardrailVersion"],
|
||||
)
|
||||
litellm.callbacks.append(_bedrock_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.LAKERA.value:
|
||||
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
|
||||
lakeraAI_Moderation,
|
||||
)
|
||||
|
||||
_lakera_callback = lakeraAI_Moderation(
|
||||
api_base=litellm_params["api_base"],
|
||||
api_key=litellm_params["api_key"],
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
category_thresholds=litellm_params.get("category_thresholds"),
|
||||
)
|
||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.AIM.value:
|
||||
from litellm.proxy.guardrails.guardrail_hooks.aim import (
|
||||
AimGuardrail,
|
||||
)
|
||||
|
||||
_aim_callback = AimGuardrail(
|
||||
api_base=litellm_params["api_base"],
|
||||
api_key=litellm_params["api_key"],
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
)
|
||||
litellm.callbacks.append(_aim_callback) # type: ignore
|
||||
elif (
|
||||
litellm_params["guardrail"] == SupportedGuardrailIntegrations.PRESIDIO.value
|
||||
if litellm_params["api_base"] and litellm_params["api_base"].startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
|
||||
_OPTIONAL_PresidioPIIMasking,
|
||||
)
|
||||
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
|
||||
|
||||
_presidio_callback = _OPTIONAL_PresidioPIIMasking(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
output_parse_pii=litellm_params["output_parse_pii"],
|
||||
presidio_ad_hoc_recognizers=litellm_params[
|
||||
"presidio_ad_hoc_recognizers"
|
||||
],
|
||||
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
|
||||
)
|
||||
guardrail_type = litellm_params["guardrail"]
|
||||
|
||||
if litellm_params["output_parse_pii"] is True:
|
||||
_success_callback = _OPTIONAL_PresidioPIIMasking(
|
||||
output_parse_pii=True,
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=GuardrailEventHooks.post_call.value,
|
||||
presidio_ad_hoc_recognizers=litellm_params[
|
||||
"presidio_ad_hoc_recognizers"
|
||||
],
|
||||
)
|
||||
initializer = guardrail_registry.get(guardrail_type)
|
||||
|
||||
litellm.callbacks.append(_success_callback) # type: ignore
|
||||
|
||||
litellm.callbacks.append(_presidio_callback) # type: ignore
|
||||
elif (
|
||||
litellm_params["guardrail"]
|
||||
== SupportedGuardrailIntegrations.HIDE_SECRETS.value
|
||||
):
|
||||
from enterprise.enterprise_hooks.secret_detection import (
|
||||
_ENTERPRISE_SecretDetection,
|
||||
)
|
||||
|
||||
_secret_detection_object = _ENTERPRISE_SecretDetection(
|
||||
detect_secrets_config=litellm_params.get("detect_secrets_config"),
|
||||
event_hook=litellm_params["mode"],
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
)
|
||||
|
||||
litellm.callbacks.append(_secret_detection_object) # type: ignore
|
||||
elif (
|
||||
litellm_params["guardrail"]
|
||||
== SupportedGuardrailIntegrations.GURDRAILS_AI.value
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.guardrails_ai import (
|
||||
GuardrailsAI,
|
||||
)
|
||||
|
||||
_guard_name = litellm_params.get("guard_name")
|
||||
if _guard_name is None:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
|
||||
)
|
||||
_guardrails_ai_callback = GuardrailsAI(
|
||||
api_base=litellm_params.get("api_base"),
|
||||
guard_name=_guard_name,
|
||||
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
|
||||
)
|
||||
|
||||
litellm.callbacks.append(_guardrails_ai_callback) # type: ignore
|
||||
elif (
|
||||
isinstance(litellm_params["guardrail"], str)
|
||||
and "." in litellm_params["guardrail"]
|
||||
):
|
||||
if config_file_path is None:
|
||||
if initializer:
|
||||
initializer(litellm_params, guardrail)
|
||||
elif isinstance(guardrail_type, str) and "." in guardrail_type:
|
||||
if not config_file_path:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
|
||||
)
|
||||
import os
|
||||
|
||||
# Custom guardrail
|
||||
_guardrail = litellm_params["guardrail"]
|
||||
_file_name, _class_name = _guardrail.split(".")
|
||||
_file_name, _class_name = guardrail_type.split(".")
|
||||
verbose_proxy_logger.debug(
|
||||
"Initializing custom guardrail: %s, file_name: %s, class_name: %s",
|
||||
_guardrail,
|
||||
guardrail_type,
|
||||
_file_name,
|
||||
_class_name,
|
||||
)
|
||||
|
||||
directory = os.path.dirname(config_file_path)
|
||||
module_file_path = os.path.join(directory, _file_name)
|
||||
module_file_path += ".py"
|
||||
module_file_path = os.path.join(directory, _file_name) + ".py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore
|
||||
if spec is None:
|
||||
if not spec:
|
||||
raise ImportError(
|
||||
f"Could not find a module specification for {module_file_path}"
|
||||
)
|
||||
|
@ -275,10 +155,11 @@ def init_guardrails_v2( # noqa: PLR0915
|
|||
_guardrail_callback = _guardrail_class(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_guardrail_callback) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {litellm_params['guardrail']}")
|
||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
|
@ -286,6 +167,5 @@ def init_guardrails_v2( # noqa: PLR0915
|
|||
)
|
||||
|
||||
guardrail_list.append(parsed_guardrail)
|
||||
guardrail["guardrail_name"]
|
||||
# pretty print guardrail_list in green
|
||||
|
||||
print(f"\nGuardrail List:{guardrail_list}\n") # noqa
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue