feat(presidio_pii_masking.py): enable output parsing for pii masking

This commit is contained in:
Krrish Dholakia 2024-02-13 21:36:57 -08:00
parent f09c09ace4
commit f68b656040
6 changed files with 181 additions and 13 deletions

View file

@ -8,14 +8,19 @@
# Tell us how we can improve! - Krrish & Ishaan
from typing import Optional
import litellm, traceback, sys
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm import ModelResponse
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
@ -24,7 +29,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
def __init__(self, mock_testing: bool = False):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
if mock_testing == True: # for testing purposes only
return
self.presidio_analyzer_api_base = litellm.get_secret(
"PRESIDIO_ANALYZER_API_BASE", None
)
@ -51,12 +62,15 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
pass
async def check_pii(self, text: str) -> str:
"""
[TODO] make this more performant for high-throughput scenario
"""
try:
async with aiohttp.ClientSession() as session:
# Make the first request to /analyze
analyze_url = f"{self.presidio_analyzer_api_base}/analyze"
analyze_payload = {"text": text, "language": "en"}
redacted_text = None
async with session.post(analyze_url, json=analyze_payload) as response:
analyze_results = await response.json()
@ -72,6 +86,26 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
) as response:
redacted_text = await response.json()
new_text = text
if redacted_text is not None:
for item in redacted_text["items"]:
start = item["start"]
end = item["end"]
replacement = item["text"] # replacement token
if (
item["operator"] == "replace"
and litellm.output_parse_pii == True
):
# check if token in dict
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
if replacement in self.pii_tokens:
replacement = replacement + uuid.uuid4()
self.pii_tokens[replacement] = new_text[
start:end
] # get text it'll replace
new_text = new_text[:start] + replacement + new_text[end:]
return redacted_text["text"]
except Exception as e:
traceback.print_exc()
@ -94,6 +128,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
if call_type == "completion": # /chat/completions requests
messages = data["messages"]
tasks = []
for m in messages:
if isinstance(m["content"], str):
tasks.append(self.check_pii(text=m["content"]))
@ -104,3 +139,30 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
"content"
] = r # replace content with redacted string
return data
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):
"""
Output parse the response object to replace the masked tokens with user sent values
"""
verbose_proxy_logger.debug(
f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
)
if litellm.output_parse_pii == False:
return response
if isinstance(response, ModelResponse) and isinstance(
response.choices, StreamingChoices
): # /chat/completions requests
if isinstance(response.choices[0].message.content, str):
verbose_proxy_logger.debug(
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
)
for key, value in self.pii_tokens.items():
response.choices[0].message.content = response.choices[
0
].message.content.replace(key, value)
return response