allow init guardrails with output parsing logic

This commit is contained in:
Ishaan Jaff 2024-09-04 14:40:35 -07:00
parent f1111f9a1b
commit 4ab8e52bfa
2 changed files with 29 additions and 66 deletions

View file

@ -48,11 +48,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
# Class variables or attributes # Class variables or attributes
def __init__( def __init__(
self, self,
logging_only: Optional[bool] = None,
mock_testing: bool = False, mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None, mock_redacted_text: Optional[dict] = None,
presidio_analyzer_api_base: Optional[str] = None, presidio_analyzer_api_base: Optional[str] = None,
presidio_anonymizer_api_base: Optional[str] = None, presidio_anonymizer_api_base: Optional[str] = None,
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
**kwargs, **kwargs,
): ):
self.pii_tokens: dict = ( self.pii_tokens: dict = (
@ -60,11 +61,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
) # mapping of PII token to original text - only used with Presidio `replace` operation ) # mapping of PII token to original text - only used with Presidio `replace` operation
self.mock_redacted_text = mock_redacted_text self.mock_redacted_text = mock_redacted_text
self.logging_only = logging_only self.output_parse_pii = output_parse_pii or False
if mock_testing is True: # for testing purposes only if mock_testing is True: # for testing purposes only
return return
ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers ad_hoc_recognizers = presidio_ad_hoc_recognizers
if ad_hoc_recognizers is not None: if ad_hoc_recognizers is not None:
try: try:
with open(ad_hoc_recognizers, "r") as file: with open(ad_hoc_recognizers, "r") as file:
@ -225,69 +226,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
""" """
try: try:
if (
self.logging_only is True
): # only modify the logging obj data (done by async_logging_hook)
return data
permissions = user_api_key_dict.permissions
output_parse_pii = permissions.get(
"output_parse_pii", litellm.output_parse_pii
) # allow key to turn on/off output parsing for pii
no_pii = permissions.get(
"no-pii", None
) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)
if no_pii is None:
# check older way of turning on/off pii
no_pii = not permissions.get("pii", True)
content_safety = data.get("content_safety", None) content_safety = data.get("content_safety", None)
verbose_proxy_logger.debug("content_safety: %s", content_safety) verbose_proxy_logger.debug("content_safety: %s", content_safety)
## Request-level turn on/off PII controls ##
if content_safety is not None and isinstance(content_safety, dict):
# pii masking ##
if (
content_safety.get("no-pii", None) is not None
and content_safety.get("no-pii") == True
):
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) == False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn off pii masking
no_pii = content_safety.get("no-pii")
if not isinstance(no_pii, bool):
raise HTTPException(
status_code=400,
detail={"error": "no_pii needs to be a boolean value"},
)
## pii output parsing ##
if content_safety.get("output_parse_pii", None) is not None:
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) == False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn on/off pii output parsing
output_parse_pii = content_safety.get("output_parse_pii")
if not isinstance(output_parse_pii, bool):
raise HTTPException(
status_code=400,
detail={
"error": "output_parse_pii needs to be a boolean value"
},
)
if no_pii is True: # turn off pii masking
return data
presidio_config = self.get_presidio_settings_from_request_data(data) presidio_config = self.get_presidio_settings_from_request_data(data)
if call_type == "completion": # /chat/completions requests if call_type == "completion": # /chat/completions requests
@ -299,7 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
tasks.append( tasks.append(
self.check_pii( self.check_pii(
text=m["content"], text=m["content"],
output_parse_pii=output_parse_pii, output_parse_pii=self.output_parse_pii,
presidio_config=presidio_config, presidio_config=presidio_config,
) )
) )
@ -372,9 +313,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
Output parse the response object to replace the masked tokens with user sent values Output parse the response object to replace the masked tokens with user sent values
""" """
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}" f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
) )
if litellm.output_parse_pii == False: if self.output_parse_pii == False:
return response return response
if isinstance(response, ModelResponse) and not isinstance( if isinstance(response, ModelResponse) and not isinstance(

View file

@ -11,6 +11,7 @@ from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_pr
# v2 implementation # v2 implementation
from litellm.types.guardrails import ( from litellm.types.guardrails import (
Guardrail, Guardrail,
GuardrailEventHooks,
GuardrailItem, GuardrailItem,
GuardrailItemSpec, GuardrailItemSpec,
LakeraCategoryThresholds, LakeraCategoryThresholds,
@ -104,6 +105,10 @@ def init_guardrails_v2(
api_base=litellm_params_data.get("api_base"), api_base=litellm_params_data.get("api_base"),
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"), guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"), guardrailVersion=litellm_params_data.get("guardrailVersion"),
output_parse_pii=litellm_params_data.get("output_parse_pii"),
presidio_ad_hoc_recognizers=litellm_params_data.get(
"presidio_ad_hoc_recognizers"
),
) )
if ( if (
@ -173,7 +178,24 @@ def init_guardrails_v2(
_presidio_callback = _OPTIONAL_PresidioPIIMasking( _presidio_callback = _OPTIONAL_PresidioPIIMasking(
guardrail_name=guardrail["guardrail_name"], guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"], event_hook=litellm_params["mode"],
output_parse_pii=litellm_params["output_parse_pii"],
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
) )
if litellm_params["output_parse_pii"] is True:
_success_callback = _OPTIONAL_PresidioPIIMasking(
output_parse_pii=True,
guardrail_name=guardrail["guardrail_name"],
event_hook=GuardrailEventHooks.post_call.value,
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
)
litellm.callbacks.append(_success_callback) # type: ignore
litellm.callbacks.append(_presidio_callback) # type: ignore litellm.callbacks.append(_presidio_callback) # type: ignore
elif ( elif (
isinstance(litellm_params["guardrail"], str) isinstance(litellm_params["guardrail"], str)