From e10eb8f6fe706c19787e2cef19a1bb20fe4ab77b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 17:21:51 -0700 Subject: [PATCH 1/6] feat(llm_guard.py): enable key-specific llm guard check --- .../enterprise_hooks/google_text_moderation.py | 1 + enterprise/enterprise_hooks/llama_guard.py | 1 + enterprise/enterprise_hooks/llm_guard.py | 15 ++++++++++++++- litellm/__init__.py | 1 + litellm/integrations/custom_logger.py | 1 + litellm/proxy/hooks/prompt_injection_detection.py | 1 + litellm/proxy/proxy_server.py | 4 +++- litellm/proxy/utils.py | 5 ++++- 8 files changed, 26 insertions(+), 3 deletions(-) diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py index 6226e0cff..b548006cf 100644 --- a/enterprise/enterprise_hooks/google_text_moderation.py +++ b/enterprise/enterprise_hooks/google_text_moderation.py @@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py index 9509e9c0b..c11a9d368 100644 --- a/enterprise/enterprise_hooks/llama_guard.py +++ b/enterprise/enterprise_hooks/llama_guard.py @@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index e23f7c1da..a973e1b13 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -30,9 +30,12 @@ litellm.set_verbose = True class _ENTERPRISE_LLMGuard(CustomLogger): # Class variables or attributes def __init__( - self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None + self, + mock_testing: bool = False, + mock_redacted_text: Optional[dict] = None, ): self.mock_redacted_text = mock_redacted_text + self.llm_guard_mode = litellm.llm_guard_mode if mock_testing == True: # for testing purposes only return self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None) @@ -95,6 +98,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ @@ -104,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger): - LLM Guard can handle things like PII Masking, etc. 
""" self.print_verbose(f"Inside LLM Guard Pre-Call Hook") + + # check if llm guard enabled for specific keys only + if self.llm_guard_mode == "key-specific": + if ( + user_api_key_dict.permissions.get("enable_llm_guard_check", False) + == False + ): + return + try: assert call_type in [ "completion", diff --git a/litellm/__init__.py b/litellm/__init__.py index 5208e5f29..364160328 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None llamaguard_unsafe_content_categories: Optional[str] = None blocked_user_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None +llm_guard_mode: Literal["all", "key-specific"] = "all" ################## logging: bool = True caching: bool = ( diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d21c751af..503b3ff9d 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): pass diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index 69744bbd3..896046e94 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): self.print_verbose( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1ccab49e4..52b806b8b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3168,7 +3168,9 @@ async def chat_completion( tasks = [] tasks.append( - proxy_logging_obj.during_call_hook(data=data, call_type="completion") + proxy_logging_obj.during_call_hook( + data=data, user_api_key_dict=user_api_key_dict, call_type="completion" + ) ) start_time = time.time() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index e6ba26269..ba8d70804 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -141,6 +141,7 @@ class ProxyLogging: async def during_call_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal[ "completion", "embeddings", @@ -157,7 +158,9 @@ class ProxyLogging: try: if isinstance(callback, CustomLogger): await callback.async_moderation_hook( - data=new_data, call_type=call_type + data=new_data, + user_api_key_dict=user_api_key_dict, + call_type=call_type, ) except Exception as e: raise e From 6d418a2920a491593003c717874524cf6e9f2ae6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 17:47:20 -0700 Subject: [PATCH 2/6] fix(llm_guard.py): working llm-guard 'key-specific' mode --- enterprise/enterprise_hooks/llm_guard.py | 9 +++++++-- litellm/__init__.py | 2 +- litellm/proxy/proxy_cli.py | 4 +++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index a973e1b13..e6b2867e5 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -107,16 +107,21 @@ class _ENTERPRISE_LLMGuard(CustomLogger): - Use the sanitized 
prompt returned - LLM Guard can handle things like PII Masking, etc. """ - self.print_verbose(f"Inside LLM Guard Pre-Call Hook") + self.print_verbose( + f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" + ) # check if llm guard enabled for specific keys only if self.llm_guard_mode == "key-specific": + self.print_verbose( + f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" + ) if ( user_api_key_dict.permissions.get("enable_llm_guard_check", False) == False ): return - + self.print_verbose("Makes LLM Guard Check") try: assert call_type in [ "completion", diff --git a/litellm/__init__.py b/litellm/__init__.py index 364160328..ea5844320 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1,6 +1,6 @@ ### INIT VARIABLES ### import threading, requests, os -from typing import Callable, List, Optional, Dict, Union, Any +from typing import Callable, List, Optional, Dict, Union, Any, Literal from litellm.caching import Cache from litellm._logging import set_verbose, _turn_on_debug, verbose_logger from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 51cc62860..0e14eb122 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -6,7 +6,7 @@ from datetime import datetime import importlib from dotenv import load_dotenv import urllib.parse as urlparse -from litellm._logging import verbose_proxy_logger + sys.path.append(os.getcwd()) @@ -20,6 +20,8 @@ telemetry = None def append_query_params(url, params): + from litellm._logging import verbose_proxy_logger + verbose_proxy_logger.debug(f"url: {url}") verbose_proxy_logger.debug(f"params: {params}") parsed_url = urlparse.urlparse(url) From 1046a6352152fab398ee5b2318a2968816f0b1a1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 17:55:53 -0700 Subject: [PATCH 3/6] test(test_llm_guard.py): unit testing for key-level llm guard enabling --- enterprise/enterprise_hooks/llm_guard.py | 29 +++++++++++++++-------- litellm/tests/test_llm_guard.py | 30 ++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index e6b2867e5..f368610cf 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -95,6 +95,21 @@ class _ENTERPRISE_LLMGuard(CustomLogger): traceback.print_exc() raise e + def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool: + if self.llm_guard_mode == "key-specific": + # check if llm guard enabled for specific keys only + self.print_verbose( + f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" + ) + if ( + user_api_key_dict.permissions.get("enable_llm_guard_check", False) + == True + ): + return True + elif self.llm_guard_mode == "all": + return True + return False + async def async_moderation_hook( self, data: dict, @@ -111,16 +126,10 @@ class _ENTERPRISE_LLMGuard(CustomLogger): f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" ) - # check if llm guard enabled for specific keys only - if self.llm_guard_mode == "key-specific": - self.print_verbose( - f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" - ) - if ( - user_api_key_dict.permissions.get("enable_llm_guard_check", False) - == False - ): - return + _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict) + if _proceed == False: + return + self.print_verbose("Makes LLM Guard Check") 
try: assert call_type in [ diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py index c0f7b065f..73ccf2a19 100644 --- a/litellm/tests/test_llm_guard.py +++ b/litellm/tests/test_llm_guard.py @@ -25,7 +25,6 @@ from litellm.caching import DualCache ### UNIT TESTS FOR LLM GUARD ### -# Test if PII masking works with input A @pytest.mark.asyncio async def test_llm_guard_valid_response(): """ @@ -60,7 +59,6 @@ async def test_llm_guard_valid_response(): pytest.fail(f"An exception occurred - {str(e)}") -# Test if PII masking works with input B (also test if the response != A's response) @pytest.mark.asyncio async def test_llm_guard_error_raising(): """ @@ -95,3 +93,31 @@ async def test_llm_guard_error_raising(): pytest.fail(f"Should have failed - {str(e)}") except Exception as e: pass + + +def test_llm_guard_key_specific_mode(): + """ + Tests to see if llm guard 'key-specific' permissions work + """ + litellm.llm_guard_mode = "key-specific" + + llm_guard = _ENTERPRISE_LLMGuard() + + _api_key = "sk-12345" + # NOT ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == False + + # ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, permissions={"enable_llm_guard_check": True} + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == True From bf7cc943fbad2270df806562575c14aa545ab4d0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 18:02:44 -0700 Subject: [PATCH 4/6] docs(enterprise.md): update docs to turn on/off llm guard per key --- docs/my-website/docs/proxy/enterprise.md | 92 +++++++++++++++++------- 1 file changed, 66 insertions(+), 26 deletions(-) diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 707123924..0c72077ee 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se ::: Features: +- ✅ Content Moderation with LLM Guard - ✅ Content Moderation with LlamaGuard - ✅ Content Moderation with Google Text Moderations -- ✅ Content Moderation with LLM Guard - ✅ Reject calls from Blocked User list - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) - ✅ Don't log/store specific requests (eg confidential LLM requests) @@ -23,6 +23,71 @@ Features: ## Content Moderation +### Content Moderation with LLM Guard + +Set the LLM Guard API Base in your environment + +```env +LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api +``` + +Add `llmguard_moderations` as a callback + +```yaml +litellm_settings: + callbacks: ["llmguard_moderations"] +``` + +Now you can easily test it + +- Make a regular /chat/completion call + +- Check your proxy logs for any statement with `LLM Guard:` + +Expected results: + +``` +LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }} +``` +#### Turn on/off per key + +**1. Update config** +```yaml +litellm_settings: + callbacks: ["llmguard_moderations"] + llm_guard_mode: "key-specific" +``` + +**2. 
Create new key** + +```bash +curl --location 'http://localhost:4000/key/generate' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{ + "models": ["fake-openai-endpoint"], + "permissions": { + "enable_llm_guard_check": true # 👈 KEY CHANGE + } +}' + +# Returns {..'key': 'my-new-key'} +``` + +**2. Test it!** + +```bash +curl --location 'http://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY +--data '{"model": "fake-openai-endpoint", "messages": [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "What do you know?"} + ] + }' +``` + + ### Content Moderation with LlamaGuard Currently works with Sagemaker's LlamaGuard endpoint. @@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"] llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt ``` -### Content Moderation with LLM Guard -Set the LLM Guard API Base in your environment - -```env -LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api -``` - -Add `llmguard_moderations` as a callback - -```yaml -litellm_settings: - callbacks: ["llmguard_moderations"] -``` - -Now you can easily test it - -- Make a regular /chat/completion call - -- Check your proxy logs for any statement with `LLM Guard:` - -Expected results: - -``` -LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }} -``` ### Content Moderation with Google Text Moderation From f62f642393f4d60eff35e2d51f53bb09d5ad76c3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 18:13:15 -0700 Subject: [PATCH 5/6] test(test_llm_guard.py): fix test --- litellm/tests/test_llm_guard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py index 73ccf2a19..694c0f35c 100644 --- a/litellm/tests/test_llm_guard.py +++ b/litellm/tests/test_llm_guard.py @@ -101,7 +101,7 @@ def test_llm_guard_key_specific_mode(): """ litellm.llm_guard_mode = "key-specific" - llm_guard = _ENTERPRISE_LLMGuard() + llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True) _api_key = "sk-12345" # NOT ENABLED From 448848018894b4c73183d319612775b1d0bd618c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Mar 2024 18:37:27 -0700 Subject: [PATCH 6/6] test(test_llm_guard.py): fix test --- litellm/tests/test_llm_guard.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py index 694c0f35c..221db3213 100644 --- a/litellm/tests/test_llm_guard.py +++ b/litellm/tests/test_llm_guard.py @@ -53,6 +53,7 @@ async def test_llm_guard_valid_response(): } ] }, + user_api_key_dict=user_api_key_dict, call_type="completion", ) except Exception as e: @@ -88,6 +89,7 @@ async def test_llm_guard_error_raising(): } ] }, + user_api_key_dict=user_api_key_dict, call_type="completion", ) pytest.fail(f"Should have failed - {str(e)}")
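
This series changes the `async_moderation_hook` signature on every `CustomLogger` subclass: the hook now receives the calling key's `UserAPIKeyAuth` object in addition to the request data and call type. A minimal sketch of a custom hook compatible with the new signature — assuming `UserAPIKeyAuth` is importable from `litellm.proxy._types` as in the patched enterprise hooks, and gating on the same `enable_llm_guard_check` permission the key-specific mode uses — not the enterprise implementation itself:

```python
# Sketch only: a custom moderation hook updated for the signature
# introduced in this series. The permission check mirrors the
# "key-specific" gating in enterprise_hooks/llm_guard.py.
from typing import Literal

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth  # assumed import location


class MyModerationHook(CustomLogger):
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,  # new parameter added by this series
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        # Skip the check unless the key's permissions explicitly opt in,
        # matching the behavior of llm_guard_mode = "key-specific".
        if not user_api_key_dict.permissions.get("enable_llm_guard_check", False):
            return
        # ... run the actual content check on data["messages"] here ...
```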