diff --git a/litellm/__init__.py b/litellm/__init__.py index f9e710609..43ea2d250 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -16,7 +16,7 @@ from litellm._logging import ( log_level, ) - +from litellm.types.guardrails import GuardrailItem from litellm.proxy._types import ( KeyManagementSystem, KeyManagementSettings, @@ -124,6 +124,7 @@ llamaguard_unsafe_content_categories: Optional[str] = None blocked_user_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" +guardrail_name_config_map: Optional[Dict[str, GuardrailItem]] = None ################## ### PREVIEW FEATURES ### enable_preview_features: bool = False diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index da9826b9b..a33803505 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -2,7 +2,7 @@ # On success, logs events to Promptlayer import os import traceback -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Tuple, Union import dotenv @@ -90,6 +90,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac ): pass + async def async_logging_hook(self): + """For masking logged request/response""" + pass + + def logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + """For masking logged request/response. Return a modified version of the request/result.""" + return kwargs, result + async def async_moderation_hook( self, data: dict, diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 0271c5714..fde907ffe 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -655,6 +655,16 @@ class Logging: result=result, litellm_logging_obj=self ) + ## LOGGING HOOK ## + + for callback in callbacks: + if isinstance(callback, CustomLogger): + self.model_call_details["input"], result = callback.logging_hook( + kwargs=self.model_call_details, + result=result, + call_type=self.call_type, + ) + for callback in callbacks: try: litellm_params = self.model_call_details.get("litellm_params", {}) diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 9c9fde533..d171d5b91 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -18,7 +18,7 @@ def initialize_guardrails( premium_user: bool, config_file_path: str, litellm_settings: dict, -): +) -> Dict[str, GuardrailItem]: try: verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}") global all_guardrails @@ -55,7 +55,11 @@ def initialize_guardrails( litellm_settings=litellm_settings, ) + return guardrail_name_config_map except Exception as e: - verbose_proxy_logger.error(f"error initializing guardrails {str(e)}") - traceback.print_exc() + verbose_proxy_logger.error( + "error initializing guardrails {}\n{}".format( + str(e), traceback.format_exc() + ) + ) raise e diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 69f63d985..193b27e12 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1467,12 +1467,14 @@ class ProxyConfig: + CommonProxyErrors.not_premium_user.value ) - initialize_guardrails( + guardrail_name_config_map = initialize_guardrails( guardrails_config=value, premium_user=premium_user, config_file_path=config_file_path, litellm_settings=litellm_settings, ) + + litellm.guardrail_name_config_map = guardrail_name_config_map elif key == "callbacks": initialize_callbacks_on_proxy( diff --git a/litellm/tests/test_guardrails_config.py b/litellm/tests/test_guardrails_config.py new file mode 100644 index 000000000..a086c8081 --- /dev/null +++ b/litellm/tests/test_guardrails_config.py @@ -0,0 +1,73 @@ +# What is this? +## Unit Tests for guardrails config +import asyncio +import inspect +import os +import sys +import time +import traceback +import uuid +from datetime import datetime + +import pytest +from pydantic import BaseModel + +import litellm.litellm_core_utils +import litellm.litellm_core_utils.litellm_logging + +sys.path.insert(0, os.path.abspath("../..")) +from typing import Any, List, Literal, Optional, Tuple, Union +from unittest.mock import AsyncMock, MagicMock, patch + +import litellm +from litellm import Cache, completion, embedding +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import LiteLLMCommonStrings + + +class CustomLoggingIntegration(CustomLogger): + def __init__(self) -> None: + super().__init__() + + def logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + input: Optional[Any] = kwargs.get("input", None) + messages: Optional[List] = kwargs.get("messages", None) + if call_type == "completion": + # assume input is of type messages + if input is not None and isinstance(input, list): + input[0]["content"] = "Hey, my name is [NAME]." + if messages is not None and isinstance(messages, List): + messages[0]["content"] = "Hey, my name is [NAME]." + + kwargs["input"] = input + kwargs["messages"] = messages + return kwargs, result + + +def test_guardrail_masking_logging_only(): + """ + Assert response is unmasked. + + Assert logged response is masked. + """ + callback = CustomLoggingIntegration() + + with patch.object(callback, "log_success_event", new=MagicMock()) as mock_call: + litellm.callbacks = [callback] + messages = [{"role": "user", "content": "Hey, my name is Peter."}] + response = completion( + model="gpt-3.5-turbo", messages=messages, mock_response="Hi Peter!" + ) + + assert response.choices[0].message.content == "Hi Peter!" # type: ignore + + mock_call.assert_called_once() + + print(mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]) + + assert ( + mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"] + == "Hey, my name is [NAME]." + ) diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 7dd06a79b..3ef20aa47 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -19,4 +19,5 @@ litellm_settings: class GuardrailItem(BaseModel): callbacks: List[str] default_on: bool + logging_only: bool guardrail_name: str