fix(presidio_pii_masking.py): support logging_only pii masking

Krrish Dholakia 2024-07-11 18:04:12 -07:00
parent 9deb9b4e3f
commit 9d918d2ac7
5 changed files with 145 additions and 8 deletions

View file

@@ -90,9 +90,11 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
     ):
         pass
 
-    async def async_logging_hook(self):
-        """For masking logged request/response"""
-        pass
+    async def async_logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """For masking logged request/response. Return a modified version of the request/result."""
+        return kwargs, result
 
     def logging_hook(
         self, kwargs: dict, result: Any, call_type: str
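This turns async_logging_hook from a no-op into a transform over the logged request/response. For example, a downstream CustomLogger subclass can now scrub its own patterns before anything reaches a logging integration; a minimal sketch (MyScrubber and SECRET_RE are illustrative names, not part of this commit):

import re
from typing import Any, Tuple

from litellm.integrations.custom_logger import CustomLogger

SECRET_RE = re.compile(r"sk-[A-Za-z0-9]+")  # hypothetical pattern: key-like tokens


class MyScrubber(CustomLogger):
    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        # Return the (possibly modified) kwargs/result that loggers should see.
        for m in kwargs.get("messages") or []:
            if isinstance(m.get("content"), str):
                m["content"] = SECRET_RE.sub("<API_KEY>", m["content"])
        return kwargs, result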

View file

@@ -1310,6 +1310,18 @@ class Logging:
                     result=result, litellm_logging_obj=self
                 )
 
+            ## LOGGING HOOK ##
+
+            for callback in callbacks:
+                if isinstance(callback, CustomLogger):
+                    self.model_call_details["input"], result = (
+                        await callback.async_logging_hook(
+                            kwargs=self.model_call_details,
+                            result=result,
+                            call_type=self.call_type,
+                        )
+                    )
+
         for callback in callbacks:
             # check if callback can run for this request
             litellm_params = self.model_call_details.get("litellm_params", {})
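Because result is reassigned on each iteration, hooks compose: each CustomLogger receives whatever earlier callbacks returned. A simplified standalone sketch of that threading (run_logging_hooks is a hypothetical helper, not code from this commit):

from litellm.integrations.custom_logger import CustomLogger


async def run_logging_hooks(callbacks: list, kwargs: dict, result, call_type: str):
    # Thread kwargs/result through every CustomLogger, mirroring the loop above.
    for callback in callbacks:
        if isinstance(callback, CustomLogger):
            kwargs, result = await callback.async_logging_hook(
                kwargs=kwargs, result=result, call_type=call_type
            )
    return kwargs, result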

View file

@@ -42,7 +42,17 @@ def initialize_callbacks_on_proxy(
                 _OPTIONAL_PresidioPIIMasking,
             )
 
-            pii_masking_object = _OPTIONAL_PresidioPIIMasking()
+            presidio_logging_only: Optional[bool] = litellm_settings.get(
+                "presidio_logging_only", None
+            )
+            if presidio_logging_only is not None:
+                presidio_logging_only = bool(
+                    presidio_logging_only
+                )  # validate boolean given
+
+            pii_masking_object = _OPTIONAL_PresidioPIIMasking(
+                logging_only=presidio_logging_only
+            )
             imported_list.append(pii_masking_object)
         elif isinstance(callback, str) and callback == "llamaguard_moderations":
             from enterprise.enterprise_hooks.llama_guard import (
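On the proxy, the new flag rides in litellm_settings next to the callback registration. A hypothetical config.yaml sketch, assuming the Presidio hook is enabled via the "presidio" callback string (the model entry is illustrative, not part of this commit):

model_list:
  - model_name: gpt-3.5-turbo           # illustrative model entry
    litellm_params:
      model: openai/gpt-3.5-turbo

litellm_settings:
  callbacks: ["presidio"]               # enables _OPTIONAL_PresidioPIIMasking
  presidio_logging_only: true           # mask logged data only, not the live request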

View file

@@ -12,7 +12,7 @@ import asyncio
 import json
 import traceback
 import uuid
-from typing import Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import aiohttp
 from fastapi import HTTPException
@@ -27,6 +27,7 @@ from litellm.utils import (
     ImageResponse,
     ModelResponse,
     StreamingChoices,
+    get_formatted_prompt,
 )
@@ -36,14 +37,18 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
     # Class variables or attributes
 
     def __init__(
-        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+        self,
+        logging_only: Optional[bool] = None,
+        mock_testing: bool = False,
+        mock_redacted_text: Optional[dict] = None,
     ):
         self.pii_tokens: dict = (
             {}
         )  # mapping of PII token to original text - only used with Presidio `replace` operation
 
         self.mock_redacted_text = mock_redacted_text
-        if mock_testing == True:  # for testing purposes only
+        self.logging_only = logging_only
+        if mock_testing is True:  # for testing purposes only
             return
 
         ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
@@ -188,6 +193,10 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
         For multiple messages in /chat/completions, we'll need to call them in parallel.
         """
         try:
+            if (
+                self.logging_only is True
+            ):  # only modify the logging obj data (done by async_logging_hook)
+                return data
             permissions = user_api_key_dict.permissions
             output_parse_pii = permissions.get(
                 "output_parse_pii", litellm.output_parse_pii
@@ -244,7 +253,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
                 },
             )
 
-            if no_pii == True:  # turn off pii masking
+            if no_pii is True:  # turn off pii masking
                 return data
 
             if call_type == "completion":  # /chat/completions requests
@@ -274,6 +283,43 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
             )
             raise e
 
+    async def async_logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """
+        Masks the input before logging to langfuse, datadog, etc.
+        """
+        if (
+            call_type == "completion" or call_type == "acompletion"
+        ):  # /chat/completions requests
+            messages: Optional[List] = kwargs.get("messages", None)
+            tasks = []
+
+            if messages is None:
+                return kwargs, result
+
+            for m in messages:
+                text_str = ""
+                if m["content"] is None:
+                    continue
+                if isinstance(m["content"], str):
+                    text_str = m["content"]
+                tasks.append(
+                    self.check_pii(text=text_str, output_parse_pii=False)
+                )  # need to pass separately b/c presidio has context window limits
+
+            responses = await asyncio.gather(*tasks)
+
+            for index, r in enumerate(responses):
+                if isinstance(messages[index]["content"], str):
+                    messages[index][
+                        "content"
+                    ] = r  # replace content with redacted string
+            verbose_proxy_logger.info(
+                f"Presidio PII Masking: Redacted pii message: {messages}"
+            )
+            kwargs["messages"] = messages
+
+        return kwargs, result  # only the logged request copy was modified
+
     async def async_post_call_success_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
View file

@@ -16,6 +16,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
 import pytest
 
 import litellm
@@ -196,3 +198,68 @@ async def test_presidio_pii_masking_input_b():
     assert "<PERSON>" in new_data["messages"][0]["content"]
     assert "<PHONE_NUMBER>" not in new_data["messages"][0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
+    pii_masking = _OPTIONAL_PresidioPIIMasking(
+        logging_only=True,
+        mock_testing=True,
+        mock_redacted_text=input_b_anonymizer_results,
+    )
+
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
+    local_cache = DualCache()
+
+    test_messages = [
+        {
+            "role": "user",
+            "content": "My name is Jane Doe, who are you? Say my name in your response",
+        }
+    ]
+
+    new_data = await pii_masking.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"messages": test_messages},
+        call_type="completion",
+    )
+
+    assert "Jane Doe" in new_data["messages"][0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_presidio_pii_masking_logging_output_only_logged_response():
+    pii_masking = _OPTIONAL_PresidioPIIMasking(
+        logging_only=True,
+        mock_testing=True,
+        mock_redacted_text=input_b_anonymizer_results,
+    )
+
+    test_messages = [
+        {
+            "role": "user",
+            "content": "My name is Jane Doe, who are you? Say my name in your response",
+        }
+    ]
+
+    with patch.object(
+        pii_masking, "async_log_success_event", new=AsyncMock()
+    ) as mock_call:
+        litellm.callbacks = [pii_masking]
+
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+        )
+
+        await asyncio.sleep(3)
+
+        assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
+
+        mock_call.assert_called_once()
+
+        print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])
+
+        assert (
+            mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
+            == "My name is <PERSON>, who are you? Say my name in your response"
+        )
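The second test is the end-to-end shape of the feature: register the hook as a callback, and loggers receive redacted messages while the caller gets the raw response. Sketched as user code, under the same Presidio env-var assumption as above (mock_response keeps the sketch free of a real model call, as in the test):

import asyncio

import litellm
from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking

litellm.callbacks = [_OPTIONAL_PresidioPIIMasking(logging_only=True)]


async def main() -> None:
    # The model sees "Jane Doe" unmasked; langfuse/datadog/etc. log "<PERSON>".
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "My name is Jane Doe"}],
        mock_response="Hi there!",
    )
    print(response.choices[0].message.content)


asyncio.run(main())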