diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index 707123924..0c72077ee 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se ::: Features: +- ✅ Content Moderation with LLM Guard - ✅ Content Moderation with LlamaGuard - ✅ Content Moderation with Google Text Moderations -- ✅ Content Moderation with LLM Guard - ✅ Reject calls from Blocked User list - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors) - ✅ Don't log/store specific requests (eg confidential LLM requests) @@ -23,6 +23,71 @@ Features: ## Content Moderation +### Content Moderation with LLM Guard + +Set the LLM Guard API Base in your environment + +```env +LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api +``` + +Add `llmguard_moderations` as a callback + +```yaml +litellm_settings: + callbacks: ["llmguard_moderations"] +``` + +Now you can easily test it + +- Make a regular /chat/completion call + +- Check your proxy logs for any statement with `LLM Guard:` + +Expected results: + +``` +LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }} +``` +#### Turn on/off per key + +**1. Update config** +```yaml +litellm_settings: + callbacks: ["llmguard_moderations"] + llm_guard_mode: "key-specific" +``` + +**2. Create new key** + +```bash +curl --location 'http://localhost:4000/key/generate' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{ + "models": ["fake-openai-endpoint"], + "permissions": { + "enable_llm_guard_check": true # 👈 KEY CHANGE + } +}' + +# Returns {..'key': 'my-new-key'} +``` + +**2. Test it!** + +```bash +curl --location 'http://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY +--data '{"model": "fake-openai-endpoint", "messages": [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "What do you know?"} + ] + }' +``` + + ### Content Moderation with LlamaGuard Currently works with Sagemaker's LlamaGuard endpoint. @@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"] llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt ``` -### Content Moderation with LLM Guard -Set the LLM Guard API Base in your environment - -```env -LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api -``` - -Add `llmguard_moderations` as a callback - -```yaml -litellm_settings: - callbacks: ["llmguard_moderations"] -``` - -Now you can easily test it - -- Make a regular /chat/completion call - -- Check your proxy logs for any statement with `LLM Guard:` - -Expected results: - -``` -LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }} -``` ### Content Moderation with Google Text Moderation diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py index 6226e0cff..b548006cf 100644 --- a/enterprise/enterprise_hooks/google_text_moderation.py +++ b/enterprise/enterprise_hooks/google_text_moderation.py @@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py index 9509e9c0b..c11a9d368 100644 --- a/enterprise/enterprise_hooks/llama_guard.py +++ b/enterprise/enterprise_hooks/llama_guard.py @@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index e23f7c1da..f368610cf 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -30,9 +30,12 @@ litellm.set_verbose = True class _ENTERPRISE_LLMGuard(CustomLogger): # Class variables or attributes def __init__( - self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None + self, + mock_testing: bool = False, + mock_redacted_text: Optional[dict] = None, ): self.mock_redacted_text = mock_redacted_text + self.llm_guard_mode = litellm.llm_guard_mode if mock_testing == True: # for testing purposes only return self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None) @@ -92,9 +95,25 @@ class _ENTERPRISE_LLMGuard(CustomLogger): traceback.print_exc() raise e + def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool: + if self.llm_guard_mode == "key-specific": + # check if llm guard enabled for specific keys only + self.print_verbose( + f"user_api_key_dict.permissions: {user_api_key_dict.permissions}" + ) + if ( + user_api_key_dict.permissions.get("enable_llm_guard_check", False) + == True + ): + return True + elif self.llm_guard_mode == "all": + return True + return False + async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): """ @@ -103,7 +122,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger): - Use the sanitized prompt returned - LLM Guard can handle things like PII Masking, etc. """ - self.print_verbose(f"Inside LLM Guard Pre-Call Hook") + self.print_verbose( + f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}" + ) + + _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict) + if _proceed == False: + return + + self.print_verbose("Makes LLM Guard Check") try: assert call_type in [ "completion", diff --git a/litellm/__init__.py b/litellm/__init__.py index 5208e5f29..ea5844320 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1,6 +1,6 @@ ### INIT VARIABLES ### import threading, requests, os -from typing import Callable, List, Optional, Dict, Union, Any +from typing import Callable, List, Optional, Dict, Union, Any, Literal from litellm.caching import Cache from litellm._logging import set_verbose, _turn_on_debug, verbose_logger from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings @@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None llamaguard_unsafe_content_categories: Optional[str] = None blocked_user_list: Optional[Union[str, List]] = None banned_keywords_list: Optional[Union[str, List]] = None +llm_guard_mode: Literal["all", "key-specific"] = "all" ################## logging: bool = True caching: bool = ( diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d21c751af..503b3ff9d 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): pass diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index 69744bbd3..896046e94 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): async def async_moderation_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal["completion", "embeddings", "image_generation"], ): self.print_verbose( diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 51cc62860..0e14eb122 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -6,7 +6,7 @@ from datetime import datetime import importlib from dotenv import load_dotenv import urllib.parse as urlparse -from litellm._logging import verbose_proxy_logger + sys.path.append(os.getcwd()) @@ -20,6 +20,8 @@ telemetry = None def append_query_params(url, params): + from litellm._logging import verbose_proxy_logger + verbose_proxy_logger.debug(f"url: {url}") verbose_proxy_logger.debug(f"params: {params}") parsed_url = urlparse.urlparse(url) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 209634fe5..8fa2862f2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3168,7 +3168,9 @@ async def chat_completion( tasks = [] tasks.append( - proxy_logging_obj.during_call_hook(data=data, call_type="completion") + proxy_logging_obj.during_call_hook( + data=data, user_api_key_dict=user_api_key_dict, call_type="completion" + ) ) start_time = time.time() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index e6ba26269..ba8d70804 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -141,6 +141,7 @@ class ProxyLogging: async def during_call_hook( self, data: dict, + user_api_key_dict: UserAPIKeyAuth, call_type: Literal[ "completion", "embeddings", @@ -157,7 +158,9 @@ class ProxyLogging: try: if isinstance(callback, CustomLogger): await callback.async_moderation_hook( - data=new_data, call_type=call_type + data=new_data, + user_api_key_dict=user_api_key_dict, + call_type=call_type, ) except Exception as e: raise e diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py index c0f7b065f..221db3213 100644 --- a/litellm/tests/test_llm_guard.py +++ b/litellm/tests/test_llm_guard.py @@ -25,7 +25,6 @@ from litellm.caching import DualCache ### UNIT TESTS FOR LLM GUARD ### -# Test if PII masking works with input A @pytest.mark.asyncio async def test_llm_guard_valid_response(): """ @@ -54,13 +53,13 @@ async def test_llm_guard_valid_response(): } ] }, + user_api_key_dict=user_api_key_dict, call_type="completion", ) except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") -# Test if PII masking works with input B (also test if the response != A's response) @pytest.mark.asyncio async def test_llm_guard_error_raising(): """ @@ -90,8 +89,37 @@ async def test_llm_guard_error_raising(): } ] }, + user_api_key_dict=user_api_key_dict, call_type="completion", ) pytest.fail(f"Should have failed - {str(e)}") except Exception as e: pass + + +def test_llm_guard_key_specific_mode(): + """ + Tests to see if llm guard 'key-specific' permissions work + """ + litellm.llm_guard_mode = "key-specific" + + llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True) + + _api_key = "sk-12345" + # NOT ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == False + + # ENABLED + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, permissions={"enable_llm_guard_check": True} + ) + + should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict) + + assert should_proceed == True