diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py
index 6226e0cff..b548006cf 100644
--- a/enterprise/enterprise_hooks/google_text_moderation.py
+++ b/enterprise/enterprise_hooks/google_text_moderation.py
@@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py
index 9509e9c0b..c11a9d368 100644
--- a/enterprise/enterprise_hooks/llama_guard.py
+++ b/enterprise/enterprise_hooks/llama_guard.py
@@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py
index e23f7c1da..a973e1b13 100644
--- a/enterprise/enterprise_hooks/llm_guard.py
+++ b/enterprise/enterprise_hooks/llm_guard.py
@@ -30,9 +30,12 @@ litellm.set_verbose = True
 class _ENTERPRISE_LLMGuard(CustomLogger):
     # Class variables or attributes
     def __init__(
-        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+        self,
+        mock_testing: bool = False,
+        mock_redacted_text: Optional[dict] = None,
     ):
         self.mock_redacted_text = mock_redacted_text
+        self.llm_guard_mode = litellm.llm_guard_mode
         if mock_testing == True:  # for testing purposes only
             return
         self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
@@ -95,6 +98,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
@@ -104,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
         - LLM Guard can handle things like PII Masking, etc.
         """
         self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
+
+        # check if llm guard enabled for specific keys only
+        if self.llm_guard_mode == "key-specific":
+            if (
+                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
+                == False
+            ):
+                return
+
         try:
             assert call_type in [
                 "completion",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 5208e5f29..364160328 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
+llm_guard_mode: Literal["all", "key-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index d21c751af..503b3ff9d 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -75,6 +75,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         pass
diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py
index 69744bbd3..896046e94 100644
--- a/litellm/proxy/hooks/prompt_injection_detection.py
+++ b/litellm/proxy/hooks/prompt_injection_detection.py
@@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         self.print_verbose(
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 1ccab49e4..52b806b8b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3168,7 +3168,9 @@ async def chat_completion(
 
         tasks = []
         tasks.append(
-            proxy_logging_obj.during_call_hook(data=data, call_type="completion")
+            proxy_logging_obj.during_call_hook(
+                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+            )
         )
 
         start_time = time.time()
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index e6ba26269..ba8d70804 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -141,6 +141,7 @@ class ProxyLogging:
     async def during_call_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal[
             "completion",
             "embeddings",
@@ -157,7 +158,9 @@ class ProxyLogging:
             try:
                 if isinstance(callback, CustomLogger):
                     await callback.async_moderation_hook(
-                        data=new_data, call_type=call_type
+                        data=new_data,
+                        user_api_key_dict=user_api_key_dict,
+                        call_type=call_type,
                     )
             except Exception as e:
                 raise e