diff --git a/enterprise/enterprise_hooks/lakera_ai.py b/enterprise/enterprise_hooks/lakera_ai.py
index 642589a25..fabaea465 100644
--- a/enterprise/enterprise_hooks/lakera_ai.py
+++ b/enterprise/enterprise_hooks/lakera_ai.py
@@ -17,7 +17,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.guardrails.init_guardrails import all_guardrails
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 from datetime import datetime
 
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9e28559e9..898a6fb9b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -125,7 +125,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
 llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
-guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None
+guardrail_name_config_map: Dict[str, GuardrailItem] = {}
 ##################
 ### PREVIEW FEATURES ###
 enable_preview_features: bool = False
diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py
index 0395730b1..cc701d65e 100644
--- a/litellm/proxy/common_utils/init_callbacks.py
+++ b/litellm/proxy/common_utils/init_callbacks.py
@@ -14,6 +14,7 @@ def initialize_callbacks_on_proxy(
     premium_user: bool,
     config_file_path: str,
     litellm_settings: dict,
+    callback_specific_params: dict = {},
 ):
     from litellm.proxy.proxy_server import prisma_client
 
@@ -25,7 +26,6 @@ def initialize_callbacks_on_proxy(
         known_compatible_callbacks = list(
             get_args(litellm._custom_logger_compatible_callbacks_literal)
         )
-
         for callback in value:  # ["presidio", ]
             if isinstance(callback, str) and callback in known_compatible_callbacks:
                 imported_list.append(callback)
@@ -54,9 +54,11 @@ def initialize_callbacks_on_proxy(
                         presidio_logging_only
                     )  # validate boolean given
 
-                pii_masking_object = _OPTIONAL_PresidioPIIMasking(
-                    logging_only=presidio_logging_only
-                )
+                params = {
+                    "logging_only": presidio_logging_only,
+                    **callback_specific_params,
+                }
+                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                 imported_list.append(pii_masking_object)
             elif isinstance(callback, str) and callback == "llamaguard_moderations":
                 from enterprise.enterprise_hooks.llama_guard import (
diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py
index 682428cc9..d6a081b4d 100644
--- a/litellm/proxy/guardrails/guardrail_helpers.py
+++ b/litellm/proxy/guardrails/guardrail_helpers.py
@@ -1,5 +1,5 @@
+import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.guardrails.init_guardrails import guardrail_name_config_map
 from litellm.proxy.proxy_server import UserAPIKeyAuth
 from litellm.types.guardrails import *
 
@@ -31,7 +31,7 @@ async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> b
                     continue
 
                 # lookup the guardrail in guardrail_name_config_map
-                guardrail_item: GuardrailItem = guardrail_name_config_map[
+                guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
                     _guardrail_name
                 ]
 
@@ -80,7 +80,9 @@ async def should_proceed_based_on_api_key(
                 continue
 
             # lookup the guardrail in guardrail_name_config_map
-            guardrail_item: GuardrailItem = guardrail_name_config_map[_guardrail_name]
+            guardrail_item: GuardrailItem = litellm.guardrail_name_config_map[
+                _guardrail_name
+            ]
 
             guardrail_callbacks = guardrail_item.callbacks
             if guardrail_name in guardrail_callbacks:
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index d171d5b91..1361a75e2 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -6,15 +6,13 @@ from pydantic import BaseModel, RootModel
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
-from litellm.types.guardrails import GuardrailItem
+from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
 
 all_guardrails: List[GuardrailItem] = []
 
-guardrail_name_config_map: Dict[str, GuardrailItem] = {}
-
 
 def initialize_guardrails(
-    guardrails_config: list,
+    guardrails_config: List[Dict[str, GuardrailItemSpec]],
     premium_user: bool,
     config_file_path: str,
     litellm_settings: dict,
@@ -28,14 +26,14 @@ def initialize_guardrails(
             {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True}}
             """
-
             for k, v in item.items():
                 guardrail_item = GuardrailItem(**v, guardrail_name=k)
                 all_guardrails.append(guardrail_item)
-                guardrail_name_config_map[k] = guardrail_item
+                litellm.guardrail_name_config_map[k] = guardrail_item
 
         # set appropriate callbacks if they are default on
         default_on_callbacks = set()
+        callback_specific_params = {}
         for guardrail in all_guardrails:
             verbose_proxy_logger.debug(guardrail.guardrail_name)
             verbose_proxy_logger.debug(guardrail.default_on)
 
@@ -46,6 +44,10 @@ def initialize_guardrails(
                     if callback not in litellm.callbacks:
                         default_on_callbacks.add(callback)
 
+                    if guardrail.logging_only is True:
+                        if callback == "presidio":
+                            callback_specific_params["logging_only"] = True
+
         default_on_callbacks_list = list(default_on_callbacks)
         if len(default_on_callbacks_list) > 0:
             initialize_callbacks_on_proxy(
@@ -53,9 +55,10 @@ def initialize_guardrails(
                 premium_user=premium_user,
                 config_file_path=config_file_path,
                 litellm_settings=litellm_settings,
+                callback_specific_params=callback_specific_params,
             )
 
-        return guardrail_name_config_map
+        return litellm.guardrail_name_config_map
     except Exception as e:
         verbose_proxy_logger.error(
             "error initializing guardrails {}\n{}".format(
diff --git a/litellm/tests/test_presidio_masking.py b/litellm/tests/test_presidio_masking.py
index b4d24bfbe..2885c07f3 100644
--- a/litellm/tests/test_presidio_masking.py
+++ b/litellm/tests/test_presidio_masking.py
@@ -263,3 +263,43 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
         mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
         == "My name is , who are you? Say my name in your response"
     )
+
+
+@pytest.mark.asyncio
+async def test_presidio_pii_masking_logging_output_only_logged_response_guardrails_config():
+    from typing import Dict, List, Optional
+
+    import litellm
+    from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
+    from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+
+    guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
+        {
+            "pii_masking": {
+                "callbacks": ["presidio"],
+                "default_on": True,
+                "logging_only": True,
+            }
+        }
+    ]
+    litellm_settings = {"guardrails": guardrails_config}
+
+    assert len(litellm.guardrail_name_config_map) == 0
+    initialize_guardrails(
+        guardrails_config=guardrails_config,
+        premium_user=True,
+        config_file_path="",
+        litellm_settings=litellm_settings,
+    )
+
+    assert len(litellm.guardrail_name_config_map) == 1
+
+    pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
+    for callback in litellm.callbacks:
+        if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
+            pii_masking_obj = callback
+
+    assert pii_masking_obj is not None
+
+    assert hasattr(pii_masking_obj, "logging_only")
+    assert pii_masking_obj.logging_only is True
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index a5b5f5562..b6cb296e8 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -1,6 +1,7 @@
-from typing import Dict, List, Optional, TypedDict, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, RootModel
+from typing_extensions import Required, TypedDict, override
 
 """
 Pydantic object defining how to set guardrails on litellm proxy
@@ -16,6 +17,12 @@ litellm_settings:
 """
 
 
+class GuardrailItemSpec(TypedDict, total=False):
+    callbacks: Required[List[str]]
+    default_on: bool
+    logging_only: Optional[bool]
+
+
 class GuardrailItem(BaseModel):
     callbacks: List[str]
     default_on: bool
@@ -25,8 +32,8 @@ class GuardrailItem(BaseModel):
     def __init__(
         self,
         callbacks: List[str],
-        default_on: bool,
         guardrail_name: str,
+        default_on: bool = False,
         logging_only: Optional[bool] = None,
     ):
         super().__init__(
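
A minimal usage sketch, not part of the patch: it shows the config shape the new GuardrailItemSpec typing describes and the new GuardrailItem default introduced above. It assumes litellm is installed with this change applied; the "pii_masking" guardrail name and its values mirror the test added in this diff and are illustrative only.

from typing import Dict, List

from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec

# Illustrative guardrails config: "logging_only": True is the value that
# initialize_guardrails() now forwards to the presidio callback via
# callback_specific_params.
guardrails_config: List[Dict[str, GuardrailItemSpec]] = [
    {
        "pii_masking": {
            "callbacks": ["presidio"],
            "default_on": True,
            "logging_only": True,
        }
    }
]

# With this patch, default_on defaults to False, so a spec may omit it.
item = GuardrailItem(callbacks=["presidio"], guardrail_name="pii_masking")
assert item.default_on is False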