feat(guardrails): Flag for PII Masking on Logging

Fixes https://github.com/BerriAI/litellm/issues/4580
This commit is contained in:
Krrish Dholakia 2024-07-11 16:09:34 -07:00
parent ca76d2fd72
commit 9deb9b4e3f
7 changed files with 107 additions and 6 deletions

View file

@@ -16,7 +16,7 @@ from litellm._logging import (
log_level, log_level,
) )
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import ( from litellm.proxy._types import (
KeyManagementSystem, KeyManagementSystem,
KeyManagementSettings, KeyManagementSettings,
@@ -124,6 +124,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None
################## ##################
### PREVIEW FEATURES ### ### PREVIEW FEATURES ###
enable_preview_features: bool = False enable_preview_features: bool = False

View file

@@ -2,7 +2,7 @@
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import os import os
import traceback import traceback
from typing import Literal, Optional, Union from typing import Any, Literal, Optional, Tuple, Union
import dotenv import dotenv
@@ -90,6 +90,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
): ):
pass pass
async def async_logging_hook(
    self,
    kwargs: Optional[dict] = None,
    result: Optional[Any] = None,
    call_type: Optional[str] = None,
) -> Tuple[Optional[dict], Optional[Any]]:
    """Async hook for masking the logged request/response.

    Mirrors the synchronous ``logging_hook`` so async integrations can
    redact sensitive fields (e.g. PII) from the payload that is about to
    be logged. Without these parameters an override had nothing to mask.
    Defaults keep the original zero-argument call working.

    Args:
        kwargs: request payload about to be logged (mutable dict).
        result: model response about to be logged.
        call_type: litellm call type (e.g. "completion").

    Returns:
        Tuple of (possibly-masked kwargs, possibly-masked result).
        Default implementation is a no-op passthrough.
    """
    return kwargs, result
def logging_hook(
    self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
    """Hook for masking the logged request/response.

    Subclasses override this to redact sensitive content (e.g. PII)
    from what is about to be logged; the values returned here replace
    the logged payload only, not what the caller receives.

    Default implementation: identity passthrough.
    """
    return (kwargs, result)
async def async_moderation_hook( async def async_moderation_hook(
self, self,
data: dict, data: dict,

View file

@@ -655,6 +655,16 @@ class Logging:
result=result, litellm_logging_obj=self result=result, litellm_logging_obj=self
) )
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
self.model_call_details["input"], result = callback.logging_hook(
kwargs=self.model_call_details,
result=result,
call_type=self.call_type,
)
for callback in callbacks: for callback in callbacks:
try: try:
litellm_params = self.model_call_details.get("litellm_params", {}) litellm_params = self.model_call_details.get("litellm_params", {})

View file

@@ -18,7 +18,7 @@ def initialize_guardrails(
premium_user: bool, premium_user: bool,
config_file_path: str, config_file_path: str,
litellm_settings: dict, litellm_settings: dict,
): ) -> Dict[str, GuardrailItem]:
try: try:
verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}") verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
global all_guardrails global all_guardrails
@@ -55,7 +55,11 @@
litellm_settings=litellm_settings, litellm_settings=litellm_settings,
) )
return guardrail_name_config_map
except Exception as e: except Exception as e:
verbose_proxy_logger.error(f"error initializing guardrails {str(e)}") verbose_proxy_logger.error(
traceback.print_exc() "error initializing guardrails {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise e raise e

View file

@@ -1467,12 +1467,14 @@ class ProxyConfig:
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
) )
initialize_guardrails( guardrail_name_config_map = initialize_guardrails(
guardrails_config=value, guardrails_config=value,
premium_user=premium_user, premium_user=premium_user,
config_file_path=config_file_path, config_file_path=config_file_path,
litellm_settings=litellm_settings, litellm_settings=litellm_settings,
) )
litellm.guardrail_name_config_map = guardrail_name_config_map
elif key == "callbacks": elif key == "callbacks":
initialize_callbacks_on_proxy( initialize_callbacks_on_proxy(

View file

@@ -0,0 +1,73 @@
# What is this?
## Unit Tests for guardrails config
import asyncio
import inspect
import os
import sys
import time
import traceback
import uuid
from datetime import datetime
import pytest
from pydantic import BaseModel
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
sys.path.insert(0, os.path.abspath("../.."))
from typing import Any, List, Literal, Optional, Tuple, Union
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings
class CustomLoggingIntegration(CustomLogger):
    """Test callback that redacts PII from the *logged* request only.

    The live request/response returned to the caller is untouched; the
    hook rewrites the copy of the payload handed to the logging path.
    """

    def __init__(self) -> None:
        super().__init__()

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        masked = "Hey, my name is [NAME]."
        raw_input: Optional[Any] = kwargs.get("input", None)
        raw_messages: Optional[List] = kwargs.get("messages", None)

        if call_type == "completion":
            # "input" / "messages" are assumed to be chat-message lists;
            # redact the first message's content in place.
            if isinstance(raw_input, list) and raw_input is not None:
                raw_input[0]["content"] = masked
            if raw_messages is not None and isinstance(raw_messages, list):
                raw_messages[0]["content"] = masked

        kwargs["input"] = raw_input
        kwargs["messages"] = raw_messages
        return kwargs, result
def test_guardrail_masking_logging_only():
    """
    End-to-end check of logging-only masking.

    The caller must see the unmasked response while the success-log
    callback must receive the masked request messages.
    """
    handler = CustomLoggingIntegration()

    with patch.object(handler, "log_success_event", new=MagicMock()) as logged:
        litellm.callbacks = [handler]
        request_messages = [{"role": "user", "content": "Hey, my name is Peter."}]
        resp = completion(
            model="gpt-3.5-turbo",
            messages=request_messages,
            mock_response="Hi Peter!",
        )

        # Live response is untouched by the masking hook.
        assert resp.choices[0].message.content == "Hi Peter!"  # type: ignore

        logged.assert_called_once()
        logged_content = logged.call_args.kwargs["kwargs"]["messages"][0]["content"]
        print(logged_content)
        # The logged copy of the request was redacted.
        assert logged_content == "Hey, my name is [NAME]."

View file

@@ -19,4 +19,5 @@ litellm_settings:
class GuardrailItem(BaseModel):
    """Config entry describing one named guardrail and how it runs.

    Loaded from the proxy config's guardrails section and stored in
    ``litellm.guardrail_name_config_map``.
    """

    # Callback names this guardrail activates.
    callbacks: List[str]
    # Whether the guardrail runs by default on every request.
    default_on: bool
    # When True, the guardrail only masks what is logged, not the live
    # response. Defaults to False so pre-existing configs that omit the
    # key still validate (a required field here would be a breaking change).
    logging_only: bool = False
    guardrail_name: str