diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index 047d1b6d37..25512716cd 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -18,16 +18,16 @@ class CustomGuardrail(CustomLogger): super().__init__(**kwargs) def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + metadata = data.get("metadata") or {} + requested_guardrails = metadata.get("guardrails") or [] verbose_logger.debug( - "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s", + "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s", self.guardrail_name, event_type, self.event_hook, + requested_guardrails, ) - metadata = data.get("metadata") or {} - requested_guardrails = metadata.get("guardrails") or [] - if self.guardrail_name not in requested_guardrails: return False diff --git a/litellm/proxy/custom_guardrail.py b/litellm/proxy/custom_guardrail.py index bdcdcee1cb..2ed989cfd3 100644 --- a/litellm/proxy/custom_guardrail.py +++ b/litellm/proxy/custom_guardrail.py @@ -1,19 +1,5 @@ -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import asyncio -import json -import sys -import traceback -import uuid -from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Union -from fastapi import HTTPException - import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache @@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail): "pass_through_endpoint", ], ) -> Optional[Union[Exception, str, dict]]: - # In this guardrail, if a user inputs `litellm` we will mask it. + """ + Runs before the LLM API call + Runs on only Input + Use this if you want to MODIFY the input + """ + + # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM _messages = data.get("messages") if _messages: for message in _messages: @@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail): """ Runs in parallel to LLM API call Runs on only Input + + This can NOT modify the input, only used to reject or accept a call before going to LLM API """ # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call @@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail): _content = message.get("content") if isinstance(_content, str): if "litellm" in _content.lower(): - _content = _content.replace("litellm", "********") - message["content"] = _content - - verbose_proxy_logger.debug( - "async_pre_call_hook: Message after masking %s", _messages - ) - pass + raise ValueError("Guardrail failed words - `litellm` detected") async def async_post_call_success_hook( self, @@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail): """ Runs on response from LLM API call + It can be used to reject a response + If a response contains the word "coffee" -> we will raise an exception """ verbose_proxy_logger.debug("async_pre_call_hook response: %s", response) diff --git a/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py index bdcdcee1cb..2ed989cfd3 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py @@ -1,19 +1,5 @@ -import os -import sys - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import asyncio -import json -import sys -import traceback -import uuid -from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Union -from fastapi import HTTPException - import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache @@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail): "pass_through_endpoint", ], ) -> Optional[Union[Exception, str, dict]]: - # In this guardrail, if a user inputs `litellm` we will mask it. + """ + Runs before the LLM API call + Runs on only Input + Use this if you want to MODIFY the input + """ + + # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM _messages = data.get("messages") if _messages: for message in _messages: @@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail): """ Runs in parallel to LLM API call Runs on only Input + + This can NOT modify the input, only used to reject or accept a call before going to LLM API """ # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call @@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail): _content = message.get("content") if isinstance(_content, str): if "litellm" in _content.lower(): - _content = _content.replace("litellm", "********") - message["content"] = _content - - verbose_proxy_logger.debug( - "async_pre_call_hook: Message after masking %s", _messages - ) - pass + raise ValueError("Guardrail failed words - `litellm` detected") async def async_post_call_success_hook( self, @@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail): """ Runs on response from LLM API call + It can be used to reject a response + If a response contains the word "coffee" -> we will raise an exception """ verbose_proxy_logger.debug("async_pre_call_hook response: %s", response) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7bdad862bc..09fc014d58 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -393,7 +393,7 @@ class ProxyLogging: try: for callback in litellm.callbacks: - _callback: Optional[CustomLogger] = None + _callback = None if isinstance(callback, str): _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( callback @@ -401,11 +401,7 @@ class ProxyLogging: else: _callback = callback # type: ignore - if ( - _callback is not None - and isinstance(_callback, CustomGuardrail) - and "pre_call_hook" in vars(_callback.__class__) - ): + if _callback is not None and isinstance(_callback, CustomGuardrail): from litellm.types.guardrails import GuardrailEventHooks if (