Merge pull request #1970 from BerriAI/litellm_fix_pii_output_parsing

feat(presidio_pii_masking.py): enable output parsing for pii masking
Krish Dholakia 2024-02-13 22:36:25 -08:00 committed by GitHub
commit f9dbd74a2c
6 changed files with 177 additions and 26 deletions

litellm/__init__.py

@@ -164,6 +164,8 @@ secret_manager_client: Optional[
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
#### PII MASKING ####
output_parse_pii: bool = False
#############################################
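
A caller opts in to this behavior by flipping the new module-level flag, exactly as the unit test added at the end of this diff does:

import litellm

# swap masked PII tokens in LLM output back to the original user-sent values
litellm.output_parse_pii = True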

litellm/integrations/custom_logger.py

@@ -2,9 +2,11 @@
# On success, logs events to Promptlayer
import dotenv, os
import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal
from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@@ -54,7 +56,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings"],
call_type: Literal["completion", "embeddings", "image_generation"],
):
pass
@@ -63,21 +65,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
):
pass
async def async_post_call_streaming_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Returns the streaming chunk before it's returned to the user
"""
pass
async def async_post_call_success_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
"""
Returns the LLM response before it's returned to the user
"""
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
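
With the reworked signature, a proxy callback can now rewrite the outgoing response object. A minimal sketch of a subclass (the class name is illustrative, not part of this commit):

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

class MyOutputEditor(CustomLogger):  # hypothetical example subclass
    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        # inspect or mutate the response before it reaches the client
        return response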

litellm/proxy/hooks/presidio_pii_masking.py

@@ -8,14 +8,19 @@
# Tell us how we can improve! - Krrish & Ishaan
from typing import Optional
import litellm, traceback, sys
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm import ModelResponse
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
@@ -24,7 +29,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
def __init__(self, mock_testing: bool = False):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
if mock_testing == True: # for testing purposes only
return
self.presidio_analyzer_api_base = litellm.get_secret(
"PRESIDIO_ANALYZER_API_BASE", None
)
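
The handler resolves the Presidio analyzer endpoint from secrets/env vars as shown above (the anonymizer endpoint is read the same way in a part of __init__ elided from this hunk). A deployment might configure it like this, with placeholder URLs:

import os

# placeholder URLs for a locally running Presidio stack - adjust to your deployment
os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002"
os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001"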
@@ -51,12 +62,15 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
pass
async def check_pii(self, text: str) -> str:
"""
[TODO] make this more performant for high-throughput scenarios
"""
try:
async with aiohttp.ClientSession() as session:
# Make the first request to /analyze
analyze_url = f"{self.presidio_analyzer_api_base}/analyze"
analyze_payload = {"text": text, "language": "en"}
redacted_text = None
async with session.post(analyze_url, json=analyze_payload) as response:
analyze_results = await response.json()
@@ -72,6 +86,26 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
) as response:
redacted_text = await response.json()
new_text = text
if redacted_text is not None:
for item in redacted_text["items"]:
start = item["start"]
end = item["end"]
replacement = item["text"] # replacement token
if (
item["operator"] == "replace"
and litellm.output_parse_pii == True
):
# check if token in dict
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
if replacement in self.pii_tokens:
replacement = replacement + str(uuid.uuid4())  # uuid4() returns a UUID object; cast to str before concatenating
self.pii_tokens[replacement] = new_text[
start:end
] # get text it'll replace
new_text = new_text[:start] + replacement + new_text[end:]
return redacted_text["text"]
except Exception as e:
traceback.print_exc()
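
To make the token bookkeeping concrete, here is roughly what an /anonymize response and the resulting pii_tokens map look like (values are illustrative; the field names match the ones read in the loop above):

# hypothetical anonymizer response for the input "my name is Jane Doe"
redacted_text = {
    "text": "my name is <PERSON>",
    "items": [
        {"start": 11, "end": 19, "operator": "replace", "text": "<PERSON>"},
    ],
}
# after the loop, with litellm.output_parse_pii enabled:
#   pii_tokens == {"<PERSON>": "Jane Doe"}
# a repeat <PERSON> match gets a uuid suffix appended, keeping every
# replacement token a unique key back to its original substring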
@@ -94,6 +128,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
if call_type == "completion": # /chat/completions requests
messages = data["messages"]
tasks = []
for m in messages:
if isinstance(m["content"], str):
tasks.append(self.check_pii(text=m["content"]))
@@ -104,3 +139,30 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
"content"
] = r # replace content with redacted string
return data
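
The effect on the request payload is exactly the shape the new unit test asserts below, e.g.:

# before the pre-call hook
{"role": "user", "content": "hello world, my name is Jane Doe. My number is: 034453334"}
# after the pre-call hook
{"role": "user", "content": "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"}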
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):
"""
Output parse the response object to replace the masked tokens with the user-sent values
"""
verbose_proxy_logger.debug(
f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
)
if litellm.output_parse_pii == False:
return response
if isinstance(response, ModelResponse) and not isinstance(
response.choices[0], StreamingChoices
): # /chat/completions requests
if isinstance(response.choices[0].message.content, str):
verbose_proxy_logger.debug(
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
)
for key, value in self.pii_tokens.items():
response.choices[0].message.content = response.choices[
0
].message.content.replace(key, value)
return response
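
The restore step is a plain string substitution over the final message content; a worked example mirroring the unit test added at the end of this diff:

pii_tokens = {"<PERSON>": "Jane Doe", "<PHONE_NUMBER>": "034453334"}
content = "Hello <PERSON>! How can I assist you today?"
for token, original in pii_tokens.items():
    content = content.replace(token, original)
assert content == "Hello Jane Doe! How can I assist you today?"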

litellm/proxy/proxy_server.py

@@ -166,9 +166,9 @@ class ProxyException(Exception):
async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
return JSONResponse(
status_code=int(exc.code)
if exc.code
else status.HTTP_500_INTERNAL_SERVER_ERROR,
status_code=(
int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
),
content={
"error": {
"message": exc.message,
@@ -2428,6 +2428,11 @@ async def chat_completion(
)
fastapi_response.headers["x-litellm-model-id"] = model_id
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
)
return response
except Exception as e:
traceback.print_exc()
@@ -4553,9 +4558,11 @@ async def get_routes():
"path": getattr(route, "path", None),
"methods": getattr(route, "methods", None),
"name": getattr(route, "name", None),
"endpoint": getattr(route, "endpoint", None).__name__
if getattr(route, "endpoint", None)
else None,
"endpoint": (
getattr(route, "endpoint", None).__name__
if getattr(route, "endpoint", None)
else None
),
}
routes.append(route_info)

litellm/proxy/utils.py

@@ -11,6 +11,7 @@ from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
@@ -377,6 +378,28 @@ class ProxyLogging:
raise e
return
async def post_call_success_hook(
self,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
user_api_key_dict: UserAPIKeyAuth,
):
"""
Allow user to modify outgoing data
Covers:
1. /chat/completions
"""
new_response = copy.deepcopy(response)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
await callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=new_response
)
except Exception as e:
raise e
return new_response
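
Note the deepcopy: every callback mutates one shared copy, so the server's original response object is never modified in place, and whatever state that copy ends up in is what flows back to the client. A self-contained sketch of that contract (toy types, for illustration only):

import asyncio, copy

class ToyCallback:  # stands in for a CustomLogger subclass
    async def async_post_call_success_hook(self, user_api_key_dict, response):
        response["text"] = response["text"].replace("<PERSON>", "Jane Doe")

async def main():
    original = {"text": "Hello <PERSON>!"}
    shared = copy.deepcopy(original)  # one copy shared by every callback
    for cb in [ToyCallback()]:
        await cb.async_post_call_success_hook(user_api_key_dict=None, response=shared)
    print(original)  # {'text': 'Hello <PERSON>!'} - untouched
    print(shared)    # {'text': 'Hello Jane Doe!'} - what the client gets

asyncio.run(main())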
### DB CONNECTOR ###
# Define the retry decorator with backoff strategy

litellm/tests/test_presidio_masking.py (new file)

@@ -0,0 +1,65 @@
# What is this?
## Unit test for presidio pii masking
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@pytest.mark.asyncio
async def test_output_parsing():
"""
- have presidio pii masking - mask an input message
- make llm completion call
- have presidio pii masking - output parse message
- assert that no masked tokens remain in the llm response
"""
litellm.output_parse_pii = True
pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
initial_message = [
{
"role": "user",
"content": "hello world, my name is Jane Doe. My number is: 034453334",
}
]
filtered_message = [
{
"role": "user",
"content": "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>",
}
]
pii_masking.pii_tokens = {"<PERSON>": "Jane Doe", "<PHONE_NUMBER>": "034453334"}
response = mock_completion(
model="gpt-3.5-turbo",
messages=filtered_message,
mock_response="Hello <PERSON>! How can I assist you today?",
)
new_response = await pii_masking.async_post_call_success_hook(
user_api_key_dict=UserAPIKeyAuth(), response=response
)
assert (
new_response.choices[0].message.content
== "Hello Jane Doe! How can I assist you today?"
)
# asyncio.run(test_output_parsing())