feat(guardrails.py): allow setting logging_only in guardrails_config for presidio pii masking integration

This commit is contained in:
Krrish Dholakia 2024-07-13 12:22:17 -07:00
parent d5f5415add
commit 6641683d66
7 changed files with 71 additions and 18 deletions

View file

@ -6,15 +6,13 @@ from pydantic import BaseModel, RootModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
all_guardrails: List[GuardrailItem] = []
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
def initialize_guardrails(
guardrails_config: list,
guardrails_config: List[Dict[str, GuardrailItemSpec]],
premium_user: bool,
config_file_path: str,
litellm_settings: dict,
@ -28,14 +26,14 @@ def initialize_guardrails(
{'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
"""
for k, v in item.items():
guardrail_item = GuardrailItem(**v, guardrail_name=k)
all_guardrails.append(guardrail_item)
guardrail_name_config_map[k] = guardrail_item
litellm.guardrail_name_config_map[k] = guardrail_item
# set appropriate callbacks if they are default on
default_on_callbacks = set()
callback_specific_params = {}
for guardrail in all_guardrails:
verbose_proxy_logger.debug(guardrail.guardrail_name)
verbose_proxy_logger.debug(guardrail.default_on)
@ -46,6 +44,10 @@ def initialize_guardrails(
if callback not in litellm.callbacks:
default_on_callbacks.add(callback)
if guardrail.logging_only is True:
if callback == "presidio":
callback_specific_params["logging_only"] = True
default_on_callbacks_list = list(default_on_callbacks)
if len(default_on_callbacks_list) > 0:
initialize_callbacks_on_proxy(
@ -53,9 +55,10 @@ def initialize_guardrails(
premium_user=premium_user,
config_file_path=config_file_path,
litellm_settings=litellm_settings,
callback_specific_params=callback_specific_params,
)
return guardrail_name_config_map
return litellm.guardrail_name_config_map
except Exception as e:
verbose_proxy_logger.error(
"error initializing guardrails {}\n{}".format(