mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(guardrails): Flag for PII Masking on Logging
Fixes https://github.com/BerriAI/litellm/issues/4580
This commit is contained in:
parent
ca76d2fd72
commit
9deb9b4e3f
7 changed files with 107 additions and 6 deletions
|
@ -16,7 +16,7 @@ from litellm._logging import (
|
||||||
log_level,
|
log_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from litellm.types.guardrails import GuardrailItem
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
KeyManagementSystem,
|
KeyManagementSystem,
|
||||||
KeyManagementSettings,
|
KeyManagementSettings,
|
||||||
|
@ -124,6 +124,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
|
||||||
blocked_user_list: Optional[Union[str, List]] = None
|
blocked_user_list: Optional[Union[str, List]] = None
|
||||||
banned_keywords_list: Optional[Union[str, List]] = None
|
banned_keywords_list: Optional[Union[str, List]] = None
|
||||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||||
|
guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None
|
||||||
##################
|
##################
|
||||||
### PREVIEW FEATURES ###
|
### PREVIEW FEATURES ###
|
||||||
enable_preview_features: bool = False
|
enable_preview_features: bool = False
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Literal, Optional, Union
|
from typing import Any, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
|
@ -90,6 +90,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def async_logging_hook(self):
|
||||||
|
"""For masking logged request/response"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def logging_hook(
|
||||||
|
self, kwargs: dict, result: Any, call_type: str
|
||||||
|
) -> Tuple[dict, Any]:
|
||||||
|
"""For masking logged request/response. Return a modified version of the request/result."""
|
||||||
|
return kwargs, result
|
||||||
|
|
||||||
async def async_moderation_hook(
|
async def async_moderation_hook(
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
|
|
|
@ -655,6 +655,16 @@ class Logging:
|
||||||
result=result, litellm_logging_obj=self
|
result=result, litellm_logging_obj=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## LOGGING HOOK ##
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
if isinstance(callback, CustomLogger):
|
||||||
|
self.model_call_details["input"], result = callback.logging_hook(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
result=result,
|
||||||
|
call_type=self.call_type,
|
||||||
|
)
|
||||||
|
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
try:
|
try:
|
||||||
litellm_params = self.model_call_details.get("litellm_params", {})
|
litellm_params = self.model_call_details.get("litellm_params", {})
|
||||||
|
|
|
@ -18,7 +18,7 @@ def initialize_guardrails(
|
||||||
premium_user: bool,
|
premium_user: bool,
|
||||||
config_file_path: str,
|
config_file_path: str,
|
||||||
litellm_settings: dict,
|
litellm_settings: dict,
|
||||||
):
|
) -> Dict[str, GuardrailItem]:
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
|
verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
|
||||||
global all_guardrails
|
global all_guardrails
|
||||||
|
@ -55,7 +55,11 @@ def initialize_guardrails(
|
||||||
litellm_settings=litellm_settings,
|
litellm_settings=litellm_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return guardrail_name_config_map
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(f"error initializing guardrails {str(e)}")
|
verbose_proxy_logger.error(
|
||||||
traceback.print_exc()
|
"error initializing guardrails {}\n{}".format(
|
||||||
|
str(e), traceback.format_exc()
|
||||||
|
)
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
|
@ -1467,12 +1467,14 @@ class ProxyConfig:
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
)
|
)
|
||||||
|
|
||||||
initialize_guardrails(
|
guardrail_name_config_map = initialize_guardrails(
|
||||||
guardrails_config=value,
|
guardrails_config=value,
|
||||||
premium_user=premium_user,
|
premium_user=premium_user,
|
||||||
config_file_path=config_file_path,
|
config_file_path=config_file_path,
|
||||||
litellm_settings=litellm_settings,
|
litellm_settings=litellm_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
litellm.guardrail_name_config_map = guardrail_name_config_map
|
||||||
elif key == "callbacks":
|
elif key == "callbacks":
|
||||||
|
|
||||||
initialize_callbacks_on_proxy(
|
initialize_callbacks_on_proxy(
|
||||||
|
|
73
litellm/tests/test_guardrails_config.py
Normal file
73
litellm/tests/test_guardrails_config.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit Tests for guardrails config
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import litellm.litellm_core_utils
|
||||||
|
import litellm.litellm_core_utils.litellm_logging
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Cache, completion, embedding
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.types.utils import LiteLLMCommonStrings
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLoggingIntegration(CustomLogger):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def logging_hook(
|
||||||
|
self, kwargs: dict, result: Any, call_type: str
|
||||||
|
) -> Tuple[dict, Any]:
|
||||||
|
input: Optional[Any] = kwargs.get("input", None)
|
||||||
|
messages: Optional[List] = kwargs.get("messages", None)
|
||||||
|
if call_type == "completion":
|
||||||
|
# assume input is of type messages
|
||||||
|
if input is not None and isinstance(input, list):
|
||||||
|
input[0]["content"] = "Hey, my name is [NAME]."
|
||||||
|
if messages is not None and isinstance(messages, List):
|
||||||
|
messages[0]["content"] = "Hey, my name is [NAME]."
|
||||||
|
|
||||||
|
kwargs["input"] = input
|
||||||
|
kwargs["messages"] = messages
|
||||||
|
return kwargs, result
|
||||||
|
|
||||||
|
|
||||||
|
def test_guardrail_masking_logging_only():
|
||||||
|
"""
|
||||||
|
Assert response is unmasked.
|
||||||
|
|
||||||
|
Assert logged response is masked.
|
||||||
|
"""
|
||||||
|
callback = CustomLoggingIntegration()
|
||||||
|
|
||||||
|
with patch.object(callback, "log_success_event", new=MagicMock()) as mock_call:
|
||||||
|
litellm.callbacks = [callback]
|
||||||
|
messages = [{"role": "user", "content": "Hey, my name is Peter."}]
|
||||||
|
response = completion(
|
||||||
|
model="gpt-3.5-turbo", messages=messages, mock_response="Hi Peter!"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.content == "Hi Peter!" # type: ignore
|
||||||
|
|
||||||
|
mock_call.assert_called_once()
|
||||||
|
|
||||||
|
print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"])
|
||||||
|
|
||||||
|
assert (
|
||||||
|
mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
|
||||||
|
== "Hey, my name is [NAME]."
|
||||||
|
)
|
|
@ -19,4 +19,5 @@ litellm_settings:
|
||||||
class GuardrailItem(BaseModel):
|
class GuardrailItem(BaseModel):
|
||||||
callbacks: List[str]
|
callbacks: List[str]
|
||||||
default_on: bool
|
default_on: bool
|
||||||
|
logging_only: bool
|
||||||
guardrail_name: str
|
guardrail_name: str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue