diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index a33803505..935431c83 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -90,9 +90,11 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
     ):
         pass
 
-    async def async_logging_hook(self):
-        """For masking logged request/response"""
-        pass
+    async def async_logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """For masking logged request/response. Return a modified version of the request/result."""
+        return kwargs, result
 
     def logging_hook(
         self, kwargs: dict, result: Any, call_type: str
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index fde907ffe..559ed7efd 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -1310,6 +1310,18 @@ class Logging:
             result=result, litellm_logging_obj=self
         )
 
+        ## LOGGING HOOK ##
+
+        for callback in callbacks:
+            if isinstance(callback, CustomLogger):
+                self.model_call_details["input"], result = (
+                    await callback.async_logging_hook(
+                        kwargs=self.model_call_details,
+                        result=result,
+                        call_type=self.call_type,
+                    )
+                )
+
         for callback in callbacks:
             # check if callback can run for this request
             litellm_params = self.model_call_details.get("litellm_params", {})
diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py
index 9631bcad4..385a45145 100644
--- a/litellm/proxy/common_utils/init_callbacks.py
+++ b/litellm/proxy/common_utils/init_callbacks.py
@@ -42,7 +42,17 @@ def initialize_callbacks_on_proxy(
                 _OPTIONAL_PresidioPIIMasking,
             )
 
-            pii_masking_object = _OPTIONAL_PresidioPIIMasking()
+            presidio_logging_only: Optional[bool] = litellm_settings.get(
+                "presidio_logging_only", None
+            )
+            if presidio_logging_only is not None:
+                presidio_logging_only = bool(
+                    presidio_logging_only
+                )  # validate boolean given
+
+            pii_masking_object = _OPTIONAL_PresidioPIIMasking(
+                logging_only=presidio_logging_only
+            )
             imported_list.append(pii_masking_object)
         elif isinstance(callback, str) and callback == "llamaguard_moderations":
             from enterprise.enterprise_hooks.llama_guard import (
diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
index 207d024e9..933d92550 100644
--- a/litellm/proxy/hooks/presidio_pii_masking.py
+++ b/litellm/proxy/hooks/presidio_pii_masking.py
@@ -12,7 +12,7 @@ import asyncio
 import json
 import traceback
 import uuid
-from typing import Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import aiohttp
 from fastapi import HTTPException
@@ -27,6 +27,7 @@ from litellm.utils import (
     ImageResponse,
     ModelResponse,
     StreamingChoices,
+    get_formatted_prompt,
 )
 
 
@@ -36,14 +37,18 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
 
     # Class variables or attributes
    def __init__(
-        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+        self,
+        logging_only: Optional[bool] = None,
+        mock_testing: bool = False,
+        mock_redacted_text: Optional[dict] = None,
     ):
         self.pii_tokens: dict = (
             {}
         )  # mapping of PII token to original text - only used with Presidio `replace` operation
 
         self.mock_redacted_text = mock_redacted_text
-        if mock_testing == True:  # for testing purposes only
+        self.logging_only = logging_only
+        if mock_testing is True:  # for testing purposes only
             return
 
         ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
@@ -188,6 +193,10 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
         For multiple messages in /chat/completions, we'll need to call them in parallel.
         """
         try:
+            if (
+                self.logging_only is True
+            ):  # only modify the logging obj data (done by async_logging_hook)
+                return data
             permissions = user_api_key_dict.permissions
             output_parse_pii = permissions.get(
                 "output_parse_pii", litellm.output_parse_pii
@@ -244,7 +253,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
                 },
             )
 
-            if no_pii == True:  # turn off pii masking
+            if no_pii is True:  # turn off pii masking
                 return data
 
             if call_type == "completion":  # /chat/completions requests
@@ -274,6 +283,43 @@ class _OPTIONAL_PresidioPIIMasking(CustomLogger):
             )
             raise e
 
+    async def async_logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """
+        Masks the input before logging to langfuse, datadog, etc.
+        """
+        if (
+            call_type == "completion" or call_type == "acompletion"
+        ):  # /chat/completions requests
+            messages: Optional[List] = kwargs.get("messages", None)
+            tasks = []
+
+            if messages is None:
+                return kwargs, result
+
+            for m in messages:
+                text_str = ""
+                if m["content"] is None:
+                    continue
+                if isinstance(m["content"], str):
+                    text_str = m["content"]
+                tasks.append(
+                    self.check_pii(text=text_str, output_parse_pii=False)
+                )  # need to pass separately b/c presidio has context window limits
+            responses = await asyncio.gather(*tasks)
+            for index, r in enumerate(responses):
+                if isinstance(messages[index]["content"], str):
+                    messages[index][
+                        "content"
+                    ] = r  # replace content with redacted string
+            verbose_proxy_logger.info(
+                f"Presidio PII Masking: Redacted pii message: {messages}"
+            )
+            kwargs["messages"] = messages
+
+        return kwargs, result
+
     async def async_post_call_success_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py
index 382644016..b4d24bfbe 100644
--- a/litellm/tests/test_presidio_masking.py
+++ b/litellm/tests/test_presidio_masking.py
@@ -16,6 +16,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
 import pytest
 
 import litellm
@@ -196,3 +198,68 @@ async def test_presidio_pii_masking_input_b():
 
     assert "<PERSON>" in new_data["messages"][0]["content"]
     assert "<PHONE_NUMBER>" not in new_data["messages"][0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
+    pii_masking = _OPTIONAL_PresidioPIIMasking(
+        logging_only=True,
+        mock_testing=True,
+        mock_redacted_text=input_b_anonymizer_results,
+    )
+
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
+    local_cache = DualCache()
+
+    test_messages = [
+        {
+            "role": "user",
+            "content": "My name is Jane Doe, who are you? Say my name in your response",
+        }
+    ]
+
+    new_data = await pii_masking.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"messages": test_messages},
+        call_type="completion",
+    )
+
+    assert "Jane Doe" in new_data["messages"][0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_presidio_pii_masking_logging_output_only_logged_response():
+    pii_masking = _OPTIONAL_PresidioPIIMasking(
+        logging_only=True,
+        mock_testing=True,
+        mock_redacted_text=input_b_anonymizer_results,
+    )
+
+    test_messages = [
+        {
+            "role": "user",
+            "content": "My name is Jane Doe, who are you? Say my name in your response",
+        }
+    ]
+    with patch.object(
+        pii_masking, "async_log_success_event", new=AsyncMock()
+    ) as mock_call:
+        litellm.callbacks = [pii_masking]
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+        )
+
+        await asyncio.sleep(3)
+
+        assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
+
+        mock_call.assert_called_once()
+
+        print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])
+
+        assert (
+            mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
+            == "My name is <PERSON>, who are you? Say my name in your response"
+        )
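
Usage sketch (not part of the patch): the new `async_logging_hook` contract in `custom_logger.py` lets any `CustomLogger` rewrite what gets logged without touching the live request or response, since `litellm_logging.py` unpacks the returned tuple back into `model_call_details["input"]` and `result`. The hook name, signature, and the `litellm.callbacks` wiring below come from the diff; the class name and the string-replacement redaction are illustrative assumptions only.

```python
from typing import Any, List, Optional, Tuple

import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyRedactionLogger(CustomLogger):  # hypothetical example class
    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """Mask message content before it reaches logging integrations."""
        if call_type in ("completion", "acompletion"):
            messages: Optional[List] = kwargs.get("messages", None)
            if messages is not None:
                for m in messages:
                    if isinstance(m.get("content"), str):
                        # naive illustrative redaction; Presidio does NER-based masking instead
                        m["content"] = m["content"].replace("Jane Doe", "<PERSON>")
                kwargs["messages"] = messages
        # return the (possibly modified) request kwargs and the untouched result
        return kwargs, result


litellm.callbacks = [MyRedactionLogger()]
```

On the proxy, the Presidio-specific version of this behavior is toggled with the `presidio_logging_only` key that `init_callbacks.py` reads from `litellm_settings` and forwards as `_OPTIONAL_PresidioPIIMasking(logging_only=...)`; when enabled, `async_pre_call_hook` returns the request untouched and only the logged copy is masked.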