Merge pull request #1931 from BerriAI/litellm_microsoft_presidio_pii

feat(proxy_server.py): support for pii masking with microsoft presidio
commit 1391490d92, by Krish Dholakia, 2024-02-11 00:27:14 -08:00, committed by GitHub
3 changed files with 133 additions and 5 deletions


litellm/proxy/hooks/presidio_pii_masking.py (new file)
@@ -0,0 +1,106 @@
# +-----------------------------------------------+
# | |
# | PII Masking |
# | with Microsoft Presidio |
# | https://github.com/BerriAI/litellm/issues/ |
# +-----------------------------------------------+
#
# Tell us how we can improve! - Krrish & Ishaan
from typing import Optional
import litellm, traceback, sys
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm import ModelResponse
from datetime import datetime
import aiohttp, asyncio


class _OPTIONAL_PresidioPIIMasking(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
self.presidio_analyzer_api_base = litellm.get_secret(
"PRESIDIO_ANALYZER_API_BASE", None
)
self.presidio_anonymizer_api_base = litellm.get_secret(
"PRESIDIO_ANONYMIZER_API_BASE", None
)
if self.presidio_analyzer_api_base is None:
raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
elif not self.presidio_analyzer_api_base.endswith("/"):
self.presidio_analyzer_api_base += "/"
if self.presidio_anonymizer_api_base is None:
raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
elif not self.presidio_anonymizer_api_base.endswith("/"):
self.presidio_anonymizer_api_base += "/"

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def check_pii(self, text: str) -> str:
        try:
            async with aiohttp.ClientSession() as session:
                # Make the first request to /analyze
                # (api_base is normalized to end with "/", so don't add another slash)
                analyze_url = f"{self.presidio_analyzer_api_base}analyze"
                analyze_payload = {"text": text, "language": "en"}
                async with session.post(analyze_url, json=analyze_payload) as response:
                    analyze_results = await response.json()

                # Make the second request to /anonymize, passing the analyzer results
                anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                anonymize_payload = {
                    "text": text,  # anonymize the actual request text, not a sample string
                    "analyzer_results": analyze_results,
                }
                async with session.post(
                    anonymize_url, json=anonymize_payload
                ) as response:
                    redacted_text = await response.json()

                return redacted_text["text"]
        except Exception as e:
            traceback.print_exc()
            raise e

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        - Take the request data
        - Call /analyze -> get the results
        - Call /anonymize w/ the analyze results -> get the redacted text

        For multiple messages in /chat/completions, we'll need to call them in parallel.
        """
        if call_type == "completion":  # /chat/completions requests
            messages = data["messages"]
            # Track which message each task belongs to, so gathered responses
            # map back to the right message even when some contents are non-string.
            tasks = []
            task_indices = []
            for i, m in enumerate(messages):
                if isinstance(m["content"], str):
                    tasks.append(self.check_pii(text=m["content"]))
                    task_indices.append(i)
            responses = await asyncio.gather(*tasks)
            for index, r in zip(task_indices, responses):
                messages[index]["content"] = r  # replace content with redacted string
        return data
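
For reference, the analyze -> anonymize round trip the hook performs can be exercised standalone. The sketch below is a minimal version, assuming a self-hosted Presidio deployment with the analyzer at http://localhost:5002 and the anonymizer at http://localhost:5001 (placeholder URLs, not part of this diff):

import asyncio

import aiohttp

# Hypothetical local endpoints -- point these at your own Presidio deployment.
ANALYZER_API_BASE = "http://localhost:5002/"
ANONYMIZER_API_BASE = "http://localhost:5001/"

async def redact(text: str) -> str:
    async with aiohttp.ClientSession() as session:
        # Step 1: /analyze returns the detected PII entities, e.g.
        # [{"entity_type": "PERSON", "start": 24, "end": 32, "score": 0.85}, ...]
        async with session.post(
            f"{ANALYZER_API_BASE}analyze",
            json={"text": text, "language": "en"},
        ) as resp:
            analyzer_results = await resp.json()

        # Step 2: /anonymize replaces each detected span (by default with the
        # entity type, e.g. <PERSON>) and returns {"text": ..., "items": [...]}
        async with session.post(
            f"{ANONYMIZER_API_BASE}anonymize",
            json={"text": text, "analyzer_results": analyzer_results},
        ) as resp:
            anonymized = await resp.json()

    return anonymized["text"]

if __name__ == "__main__":
    sample = "hello world, my name is Jane Doe. My number is: 034453334"
    print(asyncio.run(redact(sample)))
    # Expected shape: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"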


litellm/proxy/proxy_server.py
@@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta, timezone
-from typing import Optional, List
+from typing import Optional, List, Callable
 import secrets, subprocess
 import hashlib, uuid
 import warnings
@@ -1294,9 +1294,31 @@ class ProxyConfig:
                         f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                     )
                 elif key == "callbacks":
-                    litellm.callbacks = [
-                        get_instance_fn(value=value, config_file_path=config_file_path)
-                    ]
+                    if isinstance(value, list):
+                        imported_list = []
+                        for callback in value:  # ["presidio", <my-custom-callback>]
+                            if isinstance(callback, str) and callback == "presidio":
+                                from litellm.proxy.hooks.presidio_pii_masking import (
+                                    _OPTIONAL_PresidioPIIMasking,
+                                )
+
+                                pii_masking_object = _OPTIONAL_PresidioPIIMasking()
+                                imported_list.append(pii_masking_object)
+                            else:
+                                imported_list.append(
+                                    get_instance_fn(
+                                        value=callback,
+                                        config_file_path=config_file_path,
+                                    )
+                                )
+                        litellm.callbacks = imported_list  # type: ignore
+                    else:
+                        litellm.callbacks = [
+                            get_instance_fn(
+                                value=value,
+                                config_file_path=config_file_path,
+                            )
+                        ]
                 verbose_proxy_logger.debug(
                     f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
                 )
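
Enabling the hook from the proxy config means listing the string "presidio" under callbacks, as the inline comment above suggests (callbacks: ["presidio", <my-custom-callback>]). The equivalent programmatic wiring, as a minimal sketch with assumed localhost ports:

import os

import litellm
from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking

# The hook reads these at construction time and raises if either is missing.
os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002"  # assumed port
os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001"  # assumed port

# Roughly what the config branch above produces for callbacks: ["presidio"]
litellm.callbacks = [_OPTIONAL_PresidioPIIMasking()]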


@@ -92,7 +92,7 @@ class ProxyLogging:
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
-        call_type: Literal["completion", "embeddings"],
+        call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
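
With the widened Literal, custom pre-call hooks can now branch on image generation requests as well. A minimal sketch of a hook written against this interface (the rejection policy is illustrative, not part of this PR):

from typing import Literal

from fastapi import HTTPException

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

class _RejectImageGenHook(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        # Illustrative policy: block image generation, pass everything else through.
        if call_type == "image_generation":
            raise HTTPException(
                status_code=400, detail={"error": "image generation disabled"}
            )
        return data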