From 2d845b12ed4dffc7f85f57c34a3e76f3d6104a21 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia <krrishdholakia@gmail.com>
Date: Sat, 10 Feb 2024 20:20:59 -0800
Subject: [PATCH] feat(proxy_server.py): support for pii masking with
 microsoft presidio

---
 litellm/proxy/hooks/presidio_pii_masking.py | 106 ++++++++++++++++++++
 litellm/proxy/proxy_server.py               |  23 ++++-
 litellm/proxy/utils.py                      |   2 +-
 3 files changed, 126 insertions(+), 5 deletions(-)
 create mode 100644 litellm/proxy/hooks/presidio_pii_masking.py

diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
new file mode 100644
index 0000000000..01a0f3dc7f
--- /dev/null
+++ b/litellm/proxy/hooks/presidio_pii_masking.py
@@ -0,0 +1,106 @@
+# +-----------------------------------------------+
+# |                                               |
+# |               PII Masking                     |
+# |         with Microsoft Presidio               |
+# |   https://github.com/BerriAI/litellm/issues/  |
+# +-----------------------------------------------+
+#
+#  Tell us how we can improve! - Krrish & Ishaan
+
+
+from typing import Optional
+import litellm, traceback, sys
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm import ModelResponse
+from datetime import datetime
+import aiohttp, asyncio
+
+
+class _OPTIONAL_PresidioPIIMasking(CustomLogger):
+    user_api_key_cache = None
+
+    # Class variables or attributes
+    def __init__(self):
+        self.presidio_analyzer_api_base = litellm.get_secret(
+            "PRESIDIO_ANALYZER_API_BASE", None
+        )
+        self.presidio_anonymizer_api_base = litellm.get_secret(
+            "PRESIDIO_ANONYMIZER_API_BASE", None
+        )
+
+        if self.presidio_analyzer_api_base is None:
+            raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
+        elif not self.presidio_analyzer_api_base.endswith("/"):
+            self.presidio_analyzer_api_base += "/"
+
+        if self.presidio_anonymizer_api_base is None:
+            raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
+        elif not self.presidio_anonymizer_api_base.endswith("/"):
+            self.presidio_anonymizer_api_base += "/"
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass
+
+    async def check_pii(self, text: str) -> str:
+        try:
+            async with aiohttp.ClientSession() as session:
+                # Make the first request to /analyze (base already ends with "/")
+                analyze_url = f"{self.presidio_analyzer_api_base}analyze"
+                analyze_payload = {"text": text, "language": "en"}
+
+                async with session.post(analyze_url, json=analyze_payload) as response:
+                    analyze_results = await response.json()
+
+                # Make the second request to /anonymize
+                anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
+                anonymize_payload = {
+                    "text": text,
+                    "analyzer_results": analyze_results,
+                }
+
+                async with session.post(
+                    anonymize_url, json=anonymize_payload
+                ) as response:
+                    redacted_text = await response.json()
+
+                return redacted_text["text"]
+        except Exception as e:
+            traceback.print_exc()
+            raise e
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,
+    ):
+        """
+        - Take the request data
+        - Call /analyze -> get the results
+        - Call /anonymize w/ the analyze results -> get the redacted text
+
+        For multiple messages in /chat/completions, we'll need to call them in parallel.
+        """
+        if call_type == "completion":  # /chat/completions requests
+            messages = data["messages"]
+            tasks = []
+            for m in messages:
+                if isinstance(m["content"], str):
+                    tasks.append(self.check_pii(text=m["content"]))
+            responses = await asyncio.gather(*tasks)
+            for index, r in enumerate(responses):
+                if isinstance(messages[index]["content"], str):
+                    messages[index][
+                        "content"
+                    ] = r  # replace content with redacted string
+        return data
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 33eaae4725..757dd8aa3e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta, timezone
-from typing import Optional, List
+from typing import Optional, List, Callable
 import secrets, subprocess
 import hashlib, uuid
 import warnings
@@ -1293,9 +1293,24 @@ class ProxyConfig:
                     f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                 )
             elif key == "callbacks":
-                litellm.callbacks = [
-                    get_instance_fn(value=value, config_file_path=config_file_path)
-                ]
+                if isinstance(value, list):
+                    imported_list = []
+                    for callback in value:  # ["presidio", ]
+                        if isinstance(callback, str) and callback == "presidio":
+                            from litellm.proxy.hooks.presidio_pii_masking import (
+                                _OPTIONAL_PresidioPIIMasking,
+                            )
+
+                            pii_masking_object = _OPTIONAL_PresidioPIIMasking()
+                            imported_list.append(pii_masking_object)
+                        else:
+                            imported_list.append(
+                                get_instance_fn(
+                                    value=callback,
+                                    config_file_path=config_file_path,
+                                )
+                            )
+                    litellm.callbacks = imported_list  # type: ignore
                 verbose_proxy_logger.debug(
                     f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
                 )
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 35b6472577..4a1eb086d1 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -92,7 +92,7 @@ class ProxyLogging:
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
-        call_type: Literal["completion", "embeddings"],
+        call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.