From f68b65604009d812b48c6f8d82fa748e7a7dae9c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 13 Feb 2024 21:36:57 -0800 Subject: [PATCH] feat(presidio_pii_masking.py): enable output parsing for pii masking --- litellm/__init__.py | 2 + litellm/integrations/custom_logger.py | 13 +++- litellm/proxy/hooks/presidio_pii_masking.py | 72 +++++++++++++++++++-- litellm/proxy/proxy_server.py | 19 ++++-- litellm/proxy/utils.py | 23 +++++++ litellm/tests/test_presidio_masking.py | 65 +++++++++++++++++++ 6 files changed, 181 insertions(+), 13 deletions(-) create mode 100644 litellm/tests/test_presidio_masking.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 6a0cb95ae..3d62e17a3 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -164,6 +164,8 @@ secret_manager_client: Optional[ ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None +#### PII MASKING #### +output_parse_pii: bool = False ############################################# diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 316e48aed..d0cdd7702 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -2,9 +2,11 @@ # On success, logs events to Promptlayer import dotenv, os import requests + from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from typing import Literal + +from typing import Literal, Union dotenv.load_dotenv() # Loading env variables using dotenv import traceback @@ -54,7 +56,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, - call_type: Literal["completion", "embeddings"], + call_type: Literal["completion", "embeddings", "image_generation"], ): pass @@ -63,6 +65,13 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac ): pass + async def async_post_call_success_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + pass + #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py index 01a0f3dc7..25cd3c54c 100644 --- a/litellm/proxy/hooks/presidio_pii_masking.py +++ b/litellm/proxy/hooks/presidio_pii_masking.py @@ -8,14 +8,19 @@ # Tell us how we can improve! 
- Krrish & Ishaan
 
-from typing import Optional
-import litellm, traceback, sys
+from typing import Optional, Literal, Union
+import litellm, traceback, sys, uuid
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
-from litellm import ModelResponse
+from litellm.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    StreamingChoices,
+)
 from datetime import datetime
 
 import aiohttp, asyncio
 
@@ -24,7 +29,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
     user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self):
+    def __init__(self, mock_testing: bool = False):
+        self.pii_tokens: dict = (
+            {}
+        )  # mapping of PII token to original text - only used with Presidio `replace` operation
+        if mock_testing == True:  # for testing purposes only
+            return
+
         self.presidio_analyzer_api_base = litellm.get_secret(
             "PRESIDIO_ANALYZER_API_BASE", None
         )
@@ -51,12 +62,15 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
         pass
 
     async def check_pii(self, text: str) -> str:
+        """
+        [TODO] make this more performant for high-throughput scenario
+        """
         try:
             async with aiohttp.ClientSession() as session:
                 # Make the first request to /analyze
                 analyze_url = f"{self.presidio_analyzer_api_base}/analyze"
                 analyze_payload = {"text": text, "language": "en"}
-
+                redacted_text = None
                 async with session.post(analyze_url, json=analyze_payload) as response:
                     analyze_results = await response.json()
 
@@ -72,6 +86,26 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
                 ) as response:
                     redacted_text = await response.json()
 
+            new_text = text
+            if redacted_text is not None:
+                for item in redacted_text["items"]:
+                    start = item["start"]
+                    end = item["end"]
+                    replacement = item["text"]  # replacement token
+                    if (
+                        item["operator"] == "replace"
+                        and litellm.output_parse_pii == True
+                    ):
+                        # check if token in dict
+                        # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
+                        if replacement in self.pii_tokens:
+                            replacement = replacement + str(uuid.uuid4())
+
+                        self.pii_tokens[replacement] = new_text[
+                            start:end
+                        ]  # get text it'll replace
+
+                    new_text = new_text[:start] + replacement + new_text[end:]
             return redacted_text["text"]
         except Exception as e:
             traceback.print_exc()
@@ -94,6 +128,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
         if call_type == "completion":  # /chat/completions requests
             messages = data["messages"]
             tasks = []
+
             for m in messages:
                 if isinstance(m["content"], str):
                     tasks.append(self.check_pii(text=m["content"]))
@@ -104,3 +139,30 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
                     "content"
                 ] = r  # replace content with redacted string
         return data
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+    ):
+        """
+        Output parse the response object to replace the masked tokens with user sent values
+        """
+        verbose_proxy_logger.debug(
+            f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
+        )
+        if litellm.output_parse_pii == False:
+            return response
+
+        if isinstance(response, ModelResponse) and not isinstance(
+            response.choices[0], StreamingChoices
+        ):  # /chat/completions requests
+            if isinstance(response.choices[0].message.content, str):
+                verbose_proxy_logger.debug(
+                    f"self.pii_tokens: 
{self.pii_tokens}; initial response: {response.choices[0].message.content}" + ) + for key, value in self.pii_tokens.items(): + response.choices[0].message.content = response.choices[ + 0 + ].message.content.replace(key, value) + return response diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 621ef08e4..d1cb12f48 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -166,9 +166,9 @@ class ProxyException(Exception): async def openai_exception_handler(request: Request, exc: ProxyException): # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions return JSONResponse( - status_code=int(exc.code) - if exc.code - else status.HTTP_500_INTERNAL_SERVER_ERROR, + status_code=( + int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR + ), content={ "error": { "message": exc.message, @@ -2410,6 +2410,11 @@ async def chat_completion( ) fastapi_response.headers["x-litellm-model-id"] = model_id + + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + user_api_key_dict=user_api_key_dict, response=response + ) return response except Exception as e: traceback.print_exc() @@ -4535,9 +4540,11 @@ async def get_routes(): "path": getattr(route, "path", None), "methods": getattr(route, "methods", None), "name": getattr(route, "name", None), - "endpoint": getattr(route, "endpoint", None).__name__ - if getattr(route, "endpoint", None) - else None, + "endpoint": ( + getattr(route, "endpoint", None).__name__ + if getattr(route, "endpoint", None) + else None + ), } routes.append(route_info) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8741e8a77..0350d54bd 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -11,6 +11,7 @@ from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) +from litellm import ModelResponse, EmbeddingResponse, ImageResponse from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck from litellm.integrations.custom_logger import CustomLogger @@ -377,6 +378,28 @@ class ProxyLogging: raise e return + async def post_call_success_hook( + self, + response: Union[ModelResponse, EmbeddingResponse, ImageResponse], + user_api_key_dict: UserAPIKeyAuth, + ): + """ + Allow user to modify outgoing data + + Covers: + 1. /chat/completions + """ + new_response = copy.deepcopy(response) + for callback in litellm.callbacks: + try: + if isinstance(callback, CustomLogger): + await callback.async_post_call_success_hook( + user_api_key_dict=user_api_key_dict, response=new_response + ) + except Exception as e: + raise e + return new_response + ### DB CONNECTOR ### # Define the retry decorator with backoff strategy diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py new file mode 100644 index 000000000..2275e78e9 --- /dev/null +++ b/litellm/tests/test_presidio_masking.py @@ -0,0 +1,65 @@ +# What is this? 
+## Unit test for presidio pii masking
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
+from litellm import Router, mock_completion
+from litellm.proxy.utils import ProxyLogging
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+
+
+@pytest.mark.asyncio
+async def test_output_parsing():
+    """
+    - have presidio pii masking - mask an input message
+    - make llm completion call
+    - have presidio pii masking - output parse message
+    - assert that no masked tokens remain in the llm response
+    """
+    litellm.output_parse_pii = True
+    pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
+
+    initial_message = [
+        {
+            "role": "user",
+            "content": "hello world, my name is Jane Doe. My number is: 034453334",
+        }
+    ]
+
+    filtered_message = [
+        {
+            "role": "user",
+            "content": "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>",
+        }
+    ]
+
+    pii_masking.pii_tokens = {"<PERSON>": "Jane Doe", "<PHONE_NUMBER>": "034453334"}
+
+    response = mock_completion(
+        model="gpt-3.5-turbo",
+        messages=filtered_message,
+        mock_response="Hello <PERSON>! How can I assist you today?",
+    )
+    new_response = await pii_masking.async_post_call_success_hook(
+        user_api_key_dict=UserAPIKeyAuth(), response=response
+    )
+
+    assert (
+        new_response.choices[0].message.content
+        == "Hello Jane Doe! How can I assist you today?"
+    )
+
+
+# asyncio.run(test_output_parsing())
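
The round trip this patch implements is easiest to see outside the proxy. Below is a rough, self-contained sketch of the idea, assuming no running Presidio analyzer/anonymizer: mask_input() stands in for check_pii() (which would call Presidio's /analyze and /anonymize endpoints) and restore_output() mirrors what async_post_call_success_hook() does with self.pii_tokens. The helper names, the detections list, and its offsets are illustrative stand-ins, not litellm or Presidio APIs.

import uuid

pii_tokens: dict = {}  # masked token -> original text, as the hook keeps on self.pii_tokens


def mask_input(text: str, detections: list) -> str:
    """Replace each detected span with its placeholder token, remembering the original value.

    `detections` plays the role of Presidio's anonymizer "items": dicts with start/end
    offsets into `text`, the replacement token, and the operator name.
    """
    new_text = text
    # Walk detections from the end of the string so earlier offsets stay valid
    # after each replacement changes the string length.
    for item in sorted(detections, key=lambda i: i["start"], reverse=True):
        if item["operator"] != "replace":
            continue
        replacement = item["text"]
        if replacement in pii_tokens:
            # Same entity type seen twice: suffix a uuid so each token maps to exactly one value.
            replacement = replacement + str(uuid.uuid4())
        pii_tokens[replacement] = text[item["start"] : item["end"]]
        new_text = new_text[: item["start"]] + replacement + new_text[item["end"] :]
    return new_text


def restore_output(llm_reply: str) -> str:
    """Swap masked tokens in the model's reply back to the user's original values."""
    for token, original in pii_tokens.items():
        llm_reply = llm_reply.replace(token, original)
    return llm_reply


if __name__ == "__main__":
    masked = mask_input(
        "hello world, my name is Jane Doe. My number is: 034453334",
        [
            {"start": 24, "end": 32, "text": "<PERSON>", "operator": "replace"},
            {"start": 48, "end": 57, "text": "<PHONE_NUMBER>", "operator": "replace"},
        ],
    )
    print(masked)  # hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>
    print(restore_output("Hello <PERSON>! How can I assist you today?"))
    # -> Hello Jane Doe! How can I assist you today?

The uuid suffix mirrors the patch's handling of repeated entity types: without it, two values that Presidio masks with the same token (for example two <PERSON> spans) would collapse into a single dictionary entry, and only one original value could be restored on the way out.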