mirror of https://github.com/BerriAI/litellm.git
migrate presidio to new guardrails
parent be3c7b401e
commit 1276c10338
3 changed files with 379 additions and 0 deletions
litellm/proxy/guardrails/guardrail_hooks/presidio.py (Normal file, 363 additions)
@@ -0,0 +1,363 @@
# +-----------------------------------------------+
# |                                               |
# |              PII Masking                      |
# |         with Microsoft Presidio               |
# |   https://github.com/BerriAI/litellm/issues/  |
# +-----------------------------------------------+
#
#  Tell us how we can improve! - Krrish & Ishaan


import asyncio
import json
import traceback
import uuid
from typing import Any, List, Optional, Tuple, Union

from fastapi import HTTPException

import litellm  # noqa: E401
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    StreamingChoices,
    get_formatted_prompt,
)
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    user_api_key_cache = None
    ad_hoc_recognizers = None

    # Class variables or attributes
    def __init__(
        self,
        logging_only: Optional[bool] = None,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        **kwargs,
    ):
        self.pii_tokens: dict = (
            {}
        )  # mapping of PII token to original text - only used with Presidio `replace` operation

        self.mock_redacted_text = mock_redacted_text
        self.logging_only = logging_only
        if mock_testing is True:  # for testing purposes only
            return

        ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError:
                raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                )
        self.async_http_client = _get_async_httpx_client()
        self.validate_environment(
            presidio_analyzer_api_base=presidio_analyzer_api_base,
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )
        # forward guardrail_name / event_hook (passed by init_guardrails_v2) to CustomGuardrail
        super().__init__(**kwargs)
    def validate_environment(
        self,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
    ):
        self.presidio_analyzer_api_base: Optional[str] = (
            presidio_analyzer_api_base
            or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
        )
        self.presidio_anonymizer_api_base: Optional[str] = (
            presidio_anonymizer_api_base
            or litellm.get_secret("PRESIDIO_ANONYMIZER_API_BASE", None)
        )  # type: ignore

        if self.presidio_analyzer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
        if not self.presidio_analyzer_api_base.endswith("/"):
            self.presidio_analyzer_api_base += "/"
        if not (
            self.presidio_analyzer_api_base.startswith("http://")
            or self.presidio_analyzer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_analyzer_api_base = (
                "http://" + self.presidio_analyzer_api_base
            )

        if self.presidio_anonymizer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
        if not self.presidio_anonymizer_api_base.endswith("/"):
            self.presidio_anonymizer_api_base += "/"
        if not (
            self.presidio_anonymizer_api_base.startswith("http://")
            or self.presidio_anonymizer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_anonymizer_api_base = (
                "http://" + self.presidio_anonymizer_api_base
            )
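    # NOTE (illustrative): both bases can also come from the environment, e.g.
    #   PRESIDIO_ANALYZER_API_BASE="http://localhost:5002"
    #   PRESIDIO_ANONYMIZER_API_BASE="http://localhost:5001"
    # (the ports shown are assumptions based on the Presidio docker quickstart)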
    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass
    async def check_pii(self, text: str, output_parse_pii: bool) -> str:  # type: ignore
        """
        [TODO] make this more performant for high-throughput scenario
        """
        try:
            if self.mock_redacted_text is not None:
                redacted_text = self.mock_redacted_text
            else:
                # Make the first request to /analyze
                analyze_url = f"{self.presidio_analyzer_api_base}analyze"
                verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                analyze_payload = {"text": text, "language": "en"}
                if self.ad_hoc_recognizers is not None:
                    analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
                redacted_text = None

                response = await self.async_http_client.post(
                    analyze_url, json=analyze_payload
                )
                analyze_results = response.json()  # httpx Response.json() is synchronous

                # Make the second request to /anonymize
                anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
                anonymize_payload = {
                    "text": text,
                    "analyzer_results": analyze_results,
                }

                response_2 = await self.async_http_client.post(
                    anonymize_url, json=anonymize_payload
                )
                redacted_text = response_2.json()

            new_text = text
            if redacted_text is not None:
                verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
                for item in redacted_text["items"]:
                    start = item["start"]
                    end = item["end"]
                    replacement = item["text"]  # replacement token
                    if item["operator"] == "replace" and output_parse_pii is True:
                        # check if token in dict
                        # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
                        if replacement in self.pii_tokens:
                            replacement = replacement + str(uuid.uuid4())

                        self.pii_tokens[replacement] = new_text[
                            start:end
                        ]  # get the text it'll replace

                    new_text = new_text[:start] + replacement + new_text[end:]
                return redacted_text["text"]
            else:
                raise Exception(f"Invalid anonymizer response: {redacted_text}")
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.guardrails.guardrail_hooks.presidio.py::check_pii(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())
            raise e
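    # For reference - shapes assumed from the Presidio HTTP API docs, not pinned
    # to a specific Presidio version:
    #   POST {analyzer}/analyze     -> [{"entity_type": "PERSON", "start": 6,
    #                                    "end": 14, "score": 0.85}, ...]
    #   POST {anonymizer}/anonymize -> {"text": "hello <PERSON>",
    #     "items": [{"start": 6, "end": 14, "entity_type": "PERSON",
    #                "text": "<PERSON>", "operator": "replace"}]}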
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        - Check if request turned off pii
        - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')

        - Take the request data
        - Call /analyze -> get the results
        - Call /anonymize w/ the analyze results -> get the redacted text

        For multiple messages in /chat/completions, we'll need to call them in parallel.
        """
        try:
            if (
                self.logging_only is True
            ):  # only modify the logging obj data (done by async_logging_hook)
                return data
            permissions = user_api_key_dict.permissions
            output_parse_pii = permissions.get(
                "output_parse_pii", litellm.output_parse_pii
            )  # allow key to turn on/off output parsing for pii
            no_pii = permissions.get(
                "no-pii", None
            )  # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)

            if no_pii is None:
                # check older way of turning on/off pii
                no_pii = not permissions.get("pii", True)

            content_safety = data.get("content_safety", None)
            verbose_proxy_logger.debug("content_safety: %s", content_safety)
            ## Request-level turn on/off PII controls ##
            if content_safety is not None and isinstance(content_safety, dict):
                ## pii masking ##
                if (
                    content_safety.get("no-pii", None) is not None
                    and content_safety.get("no-pii") is True
                ):
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) is False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else:  # user allowed to turn off pii masking
                        no_pii = content_safety.get("no-pii")
                        if not isinstance(no_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={"error": "no_pii needs to be a boolean value"},
                            )
                ## pii output parsing ##
                if content_safety.get("output_parse_pii", None) is not None:
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) is False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else:  # user allowed to turn on/off pii output parsing
                        output_parse_pii = content_safety.get("output_parse_pii")
                        if not isinstance(output_parse_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={
                                    "error": "output_parse_pii needs to be a boolean value"
                                },
                            )

            if no_pii is True:  # turn off pii masking
                return data

            if call_type == "completion":  # /chat/completions requests
                messages = data["messages"]
                tasks = []
                task_indices = []  # content may be non-str; track which messages we masked

                for index, m in enumerate(messages):
                    if isinstance(m["content"], str):
                        task_indices.append(index)
                        tasks.append(
                            self.check_pii(
                                text=m["content"], output_parse_pii=output_parse_pii
                            )
                        )
                responses = await asyncio.gather(*tasks)
                for index, r in zip(task_indices, responses):
                    messages[index][
                        "content"
                    ] = r  # replace content with redacted string
                verbose_proxy_logger.info(
                    f"Presidio PII Masking: Redacted pii message: {data['messages']}"
                )
            return data
        except Exception as e:
            verbose_proxy_logger.error("An error occurred - {}".format(str(e)))
            raise e
    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """
        Masks the input before logging to langfuse, datadog, etc.
        """
        if (
            call_type == "completion" or call_type == "acompletion"
        ):  # /chat/completions requests
            messages: Optional[List] = kwargs.get("messages", None)
            tasks = []

            if messages is None:
                return kwargs, result

            task_indices = []  # content may be None or non-str; track which messages we masked
            for index, m in enumerate(messages):
                if m["content"] is None:
                    continue
                if isinstance(m["content"], str):
                    task_indices.append(index)
                    tasks.append(
                        self.check_pii(text=m["content"], output_parse_pii=False)
                    )  # need to pass separately b/c presidio has context window limits
            responses = await asyncio.gather(*tasks)
            for index, r in zip(task_indices, responses):
                messages[index][
                    "content"
                ] = r  # replace content with redacted string
            verbose_proxy_logger.info(
                f"Presidio PII Masking: Redacted pii message: {messages}"
            )
            kwargs["messages"] = messages

        return kwargs, result  # pass the llm response through unchanged
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
    ):
        """
        Output parse the response object to replace the masked tokens with user sent values
        """
        verbose_proxy_logger.debug(
            f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
        )
        if litellm.output_parse_pii is False:
            return response

        if isinstance(response, ModelResponse) and not isinstance(
            response.choices[0], StreamingChoices
        ):  # /chat/completions requests
            if isinstance(response.choices[0].message.content, str):
                verbose_proxy_logger.debug(
                    f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
                )
                for key, value in self.pii_tokens.items():
                    response.choices[0].message.content = response.choices[
                        0
                    ].message.content.replace(key, value)
        return response
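As a sanity check on the flow above, a minimal sketch that drives check_pii through the mock_testing / mock_redacted_text constructor flags (the sample strings and offsets are made up; no running Presidio services are needed):

import asyncio

from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,
)

# "Jane Doe" occupies offsets 6-14 of the sample input below
pii_masking = _OPTIONAL_PresidioPIIMasking(
    mock_testing=True,  # skips analyzer/anonymizer env validation
    mock_redacted_text={
        "text": "hello <PERSON>",
        "items": [
            {
                "start": 6,
                "end": 14,
                "entity_type": "PERSON",
                "text": "<PERSON>",
                "operator": "replace",
            }
        ],
    },
)

masked = asyncio.run(
    pii_masking.check_pii(text="hello Jane Doe", output_parse_pii=True)
)
print(masked)  # hello <PERSON>
print(pii_masking.pii_tokens)  # {'<PERSON>': 'Jane Doe'} - used by the post-call hook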
@@ -165,6 +165,16 @@ def init_guardrails_v2(
                category_thresholds=litellm_params.get("category_thresholds"),
            )
            litellm.callbacks.append(_lakera_callback)  # type: ignore
        elif litellm_params["guardrail"] == "presidio":
            from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                _OPTIONAL_PresidioPIIMasking,
            )

            _presidio_callback = _OPTIONAL_PresidioPIIMasking(
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
            )
            litellm.callbacks.append(_presidio_callback)  # type: ignore
        elif (
            isinstance(litellm_params["guardrail"], str)
            and "." in litellm_params["guardrail"]
@@ -15,6 +15,12 @@ litellm_settings:
   success_callback: ["prometheus"]
   failure_callback: ["prometheus"]

guardrails:
  - guardrail_name: "presidio"
    litellm_params:
      guardrail: presidio # supported values: "aporia", "lakera", "presidio"
      mode: "pre_call" # pre_call, during_call, post_call

general_settings:
  master_key: sk-1234
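For context, a hypothetical client call against this config. The base_url and model alias are placeholders (the alias must exist elsewhere in the proxy config), the key is the master_key above, and the content_safety keys mirror what async_pre_call_hook reads; per-request overrides are only honored when the key's permissions include allow_pii_controls:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumes this alias is defined in the proxy config
    messages=[
        {"role": "user", "content": "My name is Jane Doe, call me at 555-555-5555"}
    ],
    # optional request-level PII controls, read by async_pre_call_hook
    extra_body={"content_safety": {"no-pii": False, "output_parse_pii": True}},
)
print(response.choices[0].message.content)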