From 9deb9b4e3f1bcc151fc1f40cf9d1c773ddc7ca96 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 11 Jul 2024 16:09:34 -0700
Subject: [PATCH] feat(guardrails): Flag for PII Masking on Logging

Fixes https://github.com/BerriAI/litellm/issues/4580
---
 litellm/__init__.py                           |  3 +-
 litellm/integrations/custom_logger.py         | 12 ++-
 litellm/litellm_core_utils/litellm_logging.py | 10 +++
 litellm/proxy/guardrails/init_guardrails.py   | 10 ++-
 litellm/proxy/proxy_server.py                 |  4 +-
 litellm/tests/test_guardrails_config.py       | 73 +++++++++++++++++++
 litellm/types/guardrails.py                   |  1 +
 7 files changed, 107 insertions(+), 6 deletions(-)
 create mode 100644 litellm/tests/test_guardrails_config.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index f9e7106090..43ea2d2503 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -16,7 +16,7 @@ from litellm._logging import (
     log_level,
 )
-
+from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
     KeyManagementSystem,
     KeyManagementSettings,
@@ -124,6 +124,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
 llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
+guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None
 ##################
 ### PREVIEW FEATURES ###
 enable_preview_features: bool = False
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index da9826b9b5..a338035054 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -2,7 +2,7 @@
 # On success, logs events to Promptlayer
 import os
 import traceback
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Tuple, Union
 
 import dotenv
 
@@ -90,6 +90,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
     ):
         pass
 
+    async def async_logging_hook(self):
+        """For masking logged request/response"""
+        pass
+
+    def logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        """For masking logged request/response. Return a modified version of the request/result."""
+        return kwargs, result
+
     async def async_moderation_hook(
         self,
         data: dict,
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 0271c57147..fde907ffef 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -655,6 +655,16 @@ class Logging:
                         result=result, litellm_logging_obj=self
                     )
 
+            ## LOGGING HOOK ##
+
+            for callback in callbacks:
+                if isinstance(callback, CustomLogger):
+                    self.model_call_details["input"], result = callback.logging_hook(
+                        kwargs=self.model_call_details,
+                        result=result,
+                        call_type=self.call_type,
+                    )
+
             for callback in callbacks:
                 try:
                     litellm_params = self.model_call_details.get("litellm_params", {})
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index 9c9fde5337..d171d5b913 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -18,7 +18,7 @@ def initialize_guardrails(
     premium_user: bool,
     config_file_path: str,
     litellm_settings: dict,
-):
+) -> Dict[str, GuardrailItem]:
     try:
         verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
         global all_guardrails
@@ -55,7 +55,11 @@ def initialize_guardrails(
             litellm_settings=litellm_settings,
         )
 
+        return guardrail_name_config_map
     except Exception as e:
-        verbose_proxy_logger.error(f"error initializing guardrails {str(e)}")
-        traceback.print_exc()
+        verbose_proxy_logger.error(
+            "error initializing guardrails {}\n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
         raise e
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 69f63d9853..193b27e123 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1467,12 +1467,14 @@ class ProxyConfig:
                         + CommonProxyErrors.not_premium_user.value
                     )
 
-                initialize_guardrails(
+                guardrail_name_config_map = initialize_guardrails(
                     guardrails_config=value,
                     premium_user=premium_user,
                     config_file_path=config_file_path,
                     litellm_settings=litellm_settings,
                 )
+
+                litellm.guardrail_name_config_map = guardrail_name_config_map
             elif key == "callbacks":
                 initialize_callbacks_on_proxy(
diff --git a/litellm/tests/test_guardrails_config.py b/litellm/tests/test_guardrails_config.py
new file mode 100644
index 0000000000..a086c80815
--- /dev/null
+++ b/litellm/tests/test_guardrails_config.py
@@ -0,0 +1,73 @@
+# What is this?
+## Unit Tests for guardrails config
+import asyncio
+import inspect
+import os
+import sys
+import time
+import traceback
+import uuid
+from datetime import datetime
+
+import pytest
+from pydantic import BaseModel
+
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+
+sys.path.insert(0, os.path.abspath("../.."))
+from typing import Any, List, Literal, Optional, Tuple, Union
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import litellm
+from litellm import Cache, completion, embedding
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import LiteLLMCommonStrings
+
+
+class CustomLoggingIntegration(CustomLogger):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        input: Optional[Any] = kwargs.get("input", None)
+        messages: Optional[List] = kwargs.get("messages", None)
+        if call_type == "completion":
+            # assume input is of type messages
+            if input is not None and isinstance(input, list):
+                input[0]["content"] = "Hey, my name is [NAME]."
+            if messages is not None and isinstance(messages, List):
+                messages[0]["content"] = "Hey, my name is [NAME]."
+
+        kwargs["input"] = input
+        kwargs["messages"] = messages
+        return kwargs, result
+
+
+def test_guardrail_masking_logging_only():
+    """
+    Assert response is unmasked.
+
+    Assert logged response is masked.
+    """
+    callback = CustomLoggingIntegration()
+
+    with patch.object(callback, "log_success_event", new=MagicMock()) as mock_call:
+        litellm.callbacks = [callback]
+        messages = [{"role": "user", "content": "Hey, my name is Peter."}]
+        response = completion(
+            model="gpt-3.5-turbo", messages=messages, mock_response="Hi Peter!"
+        )
+
+        assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
+
+        mock_call.assert_called_once()
+
+        print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])
+
+        assert (
+            mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
+            == "Hey, my name is [NAME]."
+        )
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 7dd06a79b1..3ef20aa47a 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -19,4 +19,5 @@ litellm_settings:
 class GuardrailItem(BaseModel):
     callbacks: List[str]
     default_on: bool
+    logging_only: bool
     guardrail_name: str
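
Usage note: the extension point this patch adds is CustomLogger.logging_hook. Logging.success_handler now invokes it for every CustomLogger callback before the per-callback success handlers run, so the hook can rewrite what gets logged (self.model_call_details and the result) while the response returned to the caller stays untouched. Below is a minimal sketch of a masking callback in the style of the new test; the email regex, the [EMAIL] placeholder, and the MaskEmailsOnLogging class name are illustrative, not part of this patch.

    import re
    from typing import Any, List, Optional, Tuple

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    # Illustrative pattern; real PII masking would be more thorough.
    EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")


    class MaskEmailsOnLogging(CustomLogger):
        def logging_hook(
            self, kwargs: dict, result: Any, call_type: str
        ) -> Tuple[dict, Any]:
            # Redact email addresses from the logged request only.
            messages: Optional[List] = kwargs.get("messages", None)
            if call_type == "completion" and isinstance(messages, list):
                for message in messages:
                    content = message.get("content")
                    if isinstance(content, str):
                        message["content"] = EMAIL_RE.sub("[EMAIL]", content)
            kwargs["messages"] = messages
            return kwargs, result


    litellm.callbacks = [MaskEmailsOnLogging()]

As in test_guardrail_masking_logging_only, the caller still receives the unmasked completion; only the kwargs later handed to log_success_event carry the redacted messages. Note that the hook mutates the message dicts in place, exactly as the shipped test does, so a defensive implementation may want to copy them first.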
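Config note: on the configuration side, GuardrailItem gains a logging_only flag, and the proxy now stores the parsed guardrail map on litellm.guardrail_name_config_map via the new return value of initialize_guardrails. A sketch of the resulting in-memory state, with a made-up guardrail name and callback string for illustration:

    import litellm
    from litellm.types.guardrails import GuardrailItem

    # Hypothetical "pii_masking" guardrail that should only affect logged
    # data, never the live request/response.
    litellm.guardrail_name_config_map = {
        "pii_masking": GuardrailItem(
            guardrail_name="pii_masking",
            callbacks=["my_pii_masker"],  # illustrative callback name
            default_on=True,
            logging_only=True,
        )
    }

One design observation: logging_only is declared without a default, so pydantic treats it as required, and existing guardrail configs that omit it would fail GuardrailItem validation; defaulting it (logging_only: bool = False) would keep older configs working.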