diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index e6b2867e5..f368610cf 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -95,6 +95,21 @@ class _ENTERPRISE_LLMGuard(CustomLogger): traceback.print_exc() raise e + def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool: + if self.llm_guard_mode == "key-specific": + # check if llm guard enabled for specific keys only + self.print_verbose( + f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" + ) + if ( + user_api_key_dict.permissions.get("enable_llm_guard_check", False) + == True + ): + return True + elif self.llm_guard_mode == "all": + return True + return False + async def async_moderation_hook( self, data: dict, @@ -111,16 +126,10 @@ class _ENTERPRISE_LLMGuard(CustomLogger): f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" ) - # check if llm guard enabled for specific keys only - if self.llm_guard_mode == "key-specific": - self.print_verbose( - f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" - ) - if ( - user_api_key_dict.permissions.get("enable_llm_guard_check", False) - == False - ): - return + _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict) + if _proceed == False: + return + self.print_verbose("Makes LLM Guard Check") try: assert call_type in [ diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py index c0f7b065f..73ccf2a19 100644 --- a/litellm/tests/test_llm_guard.py +++ b/litellm/tests/test_llm_guard.py @@ -25,7 +25,6 @@ from litellm.caching import DualCache ### UNIT TESTS FOR LLM GUARD ### -# Test if PII masking works with input A @pytest.mark.asyncio async def test_llm_guard_valid_response(): """ @@ -60,7 +59,6 @@ async def test_llm_guard_valid_response(): pytest.fail(f"An exception occurred - {str(e)}") -# Test if PII masking works with input B (also test if the response != A's response) @pytest.mark.asyncio async def test_llm_guard_error_raising(): """ @@ -95,3 +93,31 @@ async def test_llm_guard_error_raising(): pytest.fail(f"Should have failed - {str(e)}") except Exception as e: pass + + +def test_llm_guard_key_specific_mode(): + """ + Tests to see if llm guard 'key-specific' permissions work + """ + litellm.llm_guard_mode = "key-specific" + + llm_guard = _ENTERPRISE_LLMGuard() + + _api_key = "sk-12345" + # NOT ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == False + + # ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, permissions={"enable_llm_guard_check": True} + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == True