allow init guardrails with output parsing logic

Ishaan Jaff 2024-09-04 14:40:35 -07:00
parent f1111f9a1b
commit 4ab8e52bfa
2 changed files with 29 additions and 66 deletions
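
In short: instead of reading the global litellm.output_parse_pii and litellm.presidio_ad_hoc_recognizers settings, the Presidio PII-masking guardrail now accepts both as constructor arguments, and init_guardrails_v2 wires them through from the guardrail's litellm_params. A minimal usage sketch under the new signature follows; the guardrail name, mode, and recognizers path are hypothetical examples, and the module path is assumed from the class shown in the diff below.

# Hedged sketch of the new constructor surface; values marked below are
# hypothetical, and the import path is assumed, not confirmed by this diff.
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,
)

pii_hook = _OPTIONAL_PresidioPIIMasking(
    guardrail_name="pii-guard",  # hypothetical guardrail name
    event_hook="pre_call",  # mode this hook runs under
    output_parse_pii=True,  # new: per-instance flag instead of the global setting
    presidio_ad_hoc_recognizers="./recognizers.json",  # new: hypothetical path to a Presidio ad-hoc recognizers JSON
)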

View file

@@ -48,11 +48,12 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    # Class variables or attributes
    def __init__(
        self,
        logging_only: Optional[bool] = None,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        **kwargs,
    ):
        self.pii_tokens: dict = (
@@ -60,11 +61,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        ) # mapping of PII token to original text - only used with Presidio `replace` operation
        self.mock_redacted_text = mock_redacted_text
        self.logging_only = logging_only
        self.output_parse_pii = output_parse_pii or False
        if mock_testing is True: # for testing purposes only
            return
        ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
        ad_hoc_recognizers = presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
@@ -225,69 +226,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        """
        try:
            if (
                self.logging_only is True
            ): # only modify the logging obj data (done by async_logging_hook)
                return data
            permissions = user_api_key_dict.permissions
            output_parse_pii = permissions.get(
                "output_parse_pii", litellm.output_parse_pii
            ) # allow key to turn on/off output parsing for pii
            no_pii = permissions.get(
                "no-pii", None
            ) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)
            if no_pii is None:
                # check older way of turning on/off pii
                no_pii = not permissions.get("pii", True)
            content_safety = data.get("content_safety", None)
            verbose_proxy_logger.debug("content_safety: %s", content_safety)
            ## Request-level turn on/off PII controls ##
            if content_safety is not None and isinstance(content_safety, dict):
                # pii masking ##
                if (
                    content_safety.get("no-pii", None) is not None
                    and content_safety.get("no-pii") == True
                ):
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) == False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else: # user allowed to turn off pii masking
                        no_pii = content_safety.get("no-pii")
                        if not isinstance(no_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={"error": "no_pii needs to be a boolean value"},
                            )
                ## pii output parsing ##
                if content_safety.get("output_parse_pii", None) is not None:
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) == False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else: # user allowed to turn on/off pii output parsing
                        output_parse_pii = content_safety.get("output_parse_pii")
                        if not isinstance(output_parse_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={
                                    "error": "output_parse_pii needs to be a boolean value"
                                },
                            )
            if no_pii is True: # turn off pii masking
                return data
            presidio_config = self.get_presidio_settings_from_request_data(data)
            if call_type == "completion": # /chat/completions requests
@@ -299,7 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                        tasks.append(
                            self.check_pii(
                                text=m["content"],
                                output_parse_pii=output_parse_pii,
                                output_parse_pii=self.output_parse_pii,
                                presidio_config=presidio_config,
                            )
                        )
@@ -372,9 +313,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        Output parse the response object to replace the masked tokens with user sent values
        """
        verbose_proxy_logger.debug(
            f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
            f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
        )
        if litellm.output_parse_pii == False:
        if self.output_parse_pii == False:
            return response
        if isinstance(response, ModelResponse) and not isinstance(

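Worth noting from this first file: while masking a request, the hook records each replacement in self.pii_tokens (a token-to-original-text mapping, per the comment in the second hunk), and the post-call path now consults the per-instance self.output_parse_pii flag rather than the global one before swapping the tokens back. A conceptual sketch of that un-masking step, not the library's exact code; the token format and helper are illustrative only:

# Conceptual sketch of output parsing, i.e. restoring user-sent values.
pii_tokens: dict = {"<PERSON_1>": "Jane Doe"}  # built while masking the request

def unmask(text: str, tokens: dict) -> str:
    # replace each masked token with the original text it stood in for
    for token, original in tokens.items():
        text = text.replace(token, original)
    return text

print(unmask("Hello <PERSON_1>!", pii_tokens))  # -> "Hello Jane Doe!"
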
View file

@@ -11,6 +11,7 @@ from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_pr
# v2 implementation
from litellm.types.guardrails import (
    Guardrail,
    GuardrailEventHooks,
    GuardrailItem,
    GuardrailItemSpec,
    LakeraCategoryThresholds,
@@ -104,6 +105,10 @@ def init_guardrails_v2(
            api_base=litellm_params_data.get("api_base"),
            guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
            guardrailVersion=litellm_params_data.get("guardrailVersion"),
            output_parse_pii=litellm_params_data.get("output_parse_pii"),
            presidio_ad_hoc_recognizers=litellm_params_data.get(
                "presidio_ad_hoc_recognizers"
            ),
        )
        if (
@@ -173,7 +178,24 @@ def init_guardrails_v2(
            _presidio_callback = _OPTIONAL_PresidioPIIMasking(
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
                output_parse_pii=litellm_params["output_parse_pii"],
                presidio_ad_hoc_recognizers=litellm_params[
                    "presidio_ad_hoc_recognizers"
                ],
            )
            if litellm_params["output_parse_pii"] is True:
                _success_callback = _OPTIONAL_PresidioPIIMasking(
                    output_parse_pii=True,
                    guardrail_name=guardrail["guardrail_name"],
                    event_hook=GuardrailEventHooks.post_call.value,
                    presidio_ad_hoc_recognizers=litellm_params[
                        "presidio_ad_hoc_recognizers"
                    ],
                )
                litellm.callbacks.append(_success_callback) # type: ignore
            litellm.callbacks.append(_presidio_callback) # type: ignore
        elif (
            isinstance(litellm_params["guardrail"], str)
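
The second file shows the registration pattern that makes this work end to end: the primary callback runs under the configured mode, and when output_parse_pii is true a second instance is registered against the post_call event hook, presumably so masked responses are un-masked even when the primary hook only runs pre-call. A rough sketch mirroring that logic, not verbatim library code; the config values and import path are hypothetical/assumed:

# Rough sketch of the registration flow added above.
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,  # import path assumed from the diff
)

litellm_params = {
    "mode": "pre_call",
    "output_parse_pii": True,
    "presidio_ad_hoc_recognizers": "./recognizers.json",  # hypothetical
}

callbacks = [
    _OPTIONAL_PresidioPIIMasking(
        guardrail_name="pii-guard",  # hypothetical
        event_hook=litellm_params["mode"],
        output_parse_pii=litellm_params["output_parse_pii"],
        presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
    )
]
if litellm_params["output_parse_pii"] is True:
    # second instance pinned to post_call handles response un-masking
    callbacks.append(
        _OPTIONAL_PresidioPIIMasking(
            guardrail_name="pii-guard",
            event_hook="post_call",  # i.e. GuardrailEventHooks.post_call.value
            output_parse_pii=True,
            presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
        )
    )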