feat(guardrails.py): allow setting logging_only in guardrails_config for presidio pii masking integration

This commit is contained in:
Krrish Dholakia 2024-07-13 12:22:17 -07:00
parent f2522867ed
commit 6b78e39600
7 changed files with 71 additions and 18 deletions

View file

@@ -17,7 +17,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.init_guardrails import all_guardrails
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from datetime import datetime from datetime import datetime

View file

@@ -125,7 +125,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None guardrail_name_config_map: Dict[str, GuardrailItem] = {}
################## ##################
### PREVIEW FEATURES ### ### PREVIEW FEATURES ###
enable_preview_features: bool = False enable_preview_features: bool = False

View file

@@ -14,6 +14,7 @@ def initialize_callbacks_on_proxy(
premium_user: bool, premium_user: bool,
config_file_path: str, config_file_path: str,
litellm_settings: dict, litellm_settings: dict,
callback_specific_params: dict = {},
): ):
from litellm.proxy.proxy_server import prisma_client from litellm.proxy.proxy_server import prisma_client
@@ -25,7 +26,6 @@ def initialize_callbacks_on_proxy(
known_compatible_callbacks = list( known_compatible_callbacks = list(
get_args(litellm._custom_logger_compatible_callbacks_literal) get_args(litellm._custom_logger_compatible_callbacks_literal)
) )
for callback in value: # ["presidio", <my-custom-callback>] for callback in value: # ["presidio", <my-custom-callback>]
if isinstance(callback, str) and callback in known_compatible_callbacks: if isinstance(callback, str) and callback in known_compatible_callbacks:
imported_list.append(callback) imported_list.append(callback)
@@ -54,9 +54,11 @@ def initialize_callbacks_on_proxy(
presidio_logging_only presidio_logging_only
) # validate boolean given ) # validate boolean given
pii_masking_object = _OPTIONAL_PresidioPIIMasking( params = {
logging_only=presidio_logging_only "logging_only": presidio_logging_only,
) **callback_specific_params,
}
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
imported_list.append(pii_masking_object) imported_list.append(pii_masking_object)
elif isinstance(callback, str) and callback == "llamaguard_moderations": elif isinstance(callback, str) and callback == "llamaguard_moderations":
from enterprise.enterprise_hooks.llama_guard import ( from enterprise.enterprise_hooks.llama_guard import (

View file

@@ -1,5 +1,5 @@
import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.guardrails.init_guardrails import guardrail_name_config_map
from litellm.proxy.proxy_server import UserAPIKeyAuth from litellm.proxy.proxy_server import UserAPIKeyAuth
from litellm.types.guardrails import * from litellm.types.guardrails import *
@@ -31,7 +31,7 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
continue continue
# lookup the guardrail in guardrail_name_config_map # lookup the guardrail in guardrail_name_config_map
guardrail_item: GuardrailItem = guardrail_name_config_map[ guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
_guardrail_name _guardrail_name
] ]
@@ -80,7 +80,9 @@ async def should_proceed_based_on_api_key(
continue continue
# lookup the guardrail in guardrail_name_config_map # lookup the guardrail in guardrail_name_config_map
guardrail_item: GuardrailItem = guardrail_name_config_map[_guardrail_name] guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
_guardrail_name
]
guardrail_callbacks = guardrail_item.callbacks guardrail_callbacks = guardrail_item.callbacks
if guardrail_name in guardrail_callbacks: if guardrail_name in guardrail_callbacks:

View file

@@ -6,15 +6,13 @@ from pydantic import BaseModel, RootModel
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.types.guardrails import GuardrailItem from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
all_guardrails: List[GuardrailItem] = [] all_guardrails: List[GuardrailItem] = []
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
def initialize_guardrails( def initialize_guardrails(
guardrails_config: list, guardrails_config: List[Dict[str, GuardrailItemSpec]],
premium_user: bool, premium_user: bool,
config_file_path: str, config_file_path: str,
litellm_settings: dict, litellm_settings: dict,
@@ -28,14 +26,14 @@ def initialize_guardrails(
{'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}} {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
""" """
for k, v in item.items(): for k, v in item.items():
guardrail_item = GuardrailItem(**v, guardrail_name=k) guardrail_item = GuardrailItem(**v, guardrail_name=k)
all_guardrails.append(guardrail_item) all_guardrails.append(guardrail_item)
guardrail_name_config_map[k] = guardrail_item litellm.guardrail_name_config_map[k] = guardrail_item
# set appropriate callbacks if they are default on # set appropriate callbacks if they are default on
default_on_callbacks = set() default_on_callbacks = set()
callback_specific_params = {}
for guardrail in all_guardrails: for guardrail in all_guardrails:
verbose_proxy_logger.debug(guardrail.guardrail_name) verbose_proxy_logger.debug(guardrail.guardrail_name)
verbose_proxy_logger.debug(guardrail.default_on) verbose_proxy_logger.debug(guardrail.default_on)
@@ -46,6 +44,10 @@ def initialize_guardrails(
if callback not in litellm.callbacks: if callback not in litellm.callbacks:
default_on_callbacks.add(callback) default_on_callbacks.add(callback)
if guardrail.logging_only is True:
if callback == "presidio":
callback_specific_params["logging_only"] = True
default_on_callbacks_list = list(default_on_callbacks) default_on_callbacks_list = list(default_on_callbacks)
if len(default_on_callbacks_list) > 0: if len(default_on_callbacks_list) > 0:
initialize_callbacks_on_proxy( initialize_callbacks_on_proxy(
@@ -53,9 +55,10 @@ def initialize_guardrails(
premium_user=premium_user, premium_user=premium_user,
config_file_path=config_file_path, config_file_path=config_file_path,
litellm_settings=litellm_settings, litellm_settings=litellm_settings,
callback_specific_params=callback_specific_params,
) )
return guardrail_name_config_map return litellm.guardrail_name_config_map
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"error initializing guardrails {}\n{}".format( "error initializing guardrails {}\n{}".format(

View file

@@ -263,3 +263,43 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"] mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
== "My name is <PERSON>, who are you? Say my name in your response" == "My name is <PERSON>, who are you? Say my name in your response"
) )
@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_logged_response_guardrails_config():
    """
    Verify that `logging_only: True` set inside guardrails_config is forwarded to
    the presidio PII-masking callback when guardrails are initialized on the proxy.
    """
    from typing import Dict, List, Optional
    import litellm
    from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
    from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec

    # guardrails_config exactly as it would appear under litellm_settings in a proxy config
    guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
        {
            "pii_masking": {
                "callbacks": ["presidio"],
                "default_on": True,
                "logging_only": True,
            }
        }
    ]

    litellm_settings = {"guardrails": guardrails_config}

    # the global map starts empty; initialize_guardrails populates it as a side effect
    assert len(litellm.guardrail_name_config_map) == 0
    initialize_guardrails(
        guardrails_config=guardrails_config,
        premium_user=True,
        config_file_path="",
        litellm_settings=litellm_settings,
    )

    assert len(litellm.guardrail_name_config_map) == 1

    # find the presidio callback instance that initialization registered globally
    pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
    for callback in litellm.callbacks:
        if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
            pii_masking_obj = callback

    # the callback must exist and carry the logging_only flag from the config
    assert pii_masking_obj is not None
    assert hasattr(pii_masking_obj, "logging_only")
    assert pii_masking_obj.logging_only is True

View file

@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, TypedDict, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, RootModel from pydantic import BaseModel, RootModel
from typing_extensions import Required, TypedDict, override
""" """
Pydantic object defining how to set guardrails on litellm proxy Pydantic object defining how to set guardrails on litellm proxy
@@ -16,6 +17,12 @@ litellm_settings:
""" """
class GuardrailItemSpec(TypedDict, total=False):
    """Shape of one guardrail entry in the proxy `guardrails` config.

    With `total=False`, every key is optional unless marked `Required`.
    """

    # callbacks implementing this guardrail, e.g. ["presidio"]
    callbacks: Required[List[str]]
    # whether the guardrail is enabled for all requests by default
    default_on: bool
    # if True, mask PII only in logged output (currently consumed by the
    # presidio callback during guardrail initialization)
    logging_only: Optional[bool]
class GuardrailItem(BaseModel): class GuardrailItem(BaseModel):
callbacks: List[str] callbacks: List[str]
default_on: bool default_on: bool
@@ -25,8 +32,8 @@ class GuardrailItem(BaseModel):
def __init__( def __init__(
self, self,
callbacks: List[str], callbacks: List[str],
default_on: bool,
guardrail_name: str, guardrail_name: str,
default_on: bool = False,
logging_only: Optional[bool] = None, logging_only: Optional[bool] = None,
): ):
super().__init__( super().__init__(