diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md
index 0c72077ee..97bae13a1 100644
--- a/docs/my-website/docs/proxy/enterprise.md
+++ b/docs/my-website/docs/proxy/enterprise.md
@@ -74,7 +74,7 @@ curl --location 'http://localhost:4000/key/generate' \
 # Returns {..'key': 'my-new-key'}
 ```
 
-**2. Test it!**
+**3. Test it!**
 
 ```bash
 curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
@@ -87,6 +87,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
 }'
 ```
 
+#### Turn on/off per request
+
+**1. Update config**
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+  llm_guard_mode: "request-specific"
+```
+
+**2. Create new key**
+
+```bash
+curl --location 'http://localhost:4000/key/generate' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "models": ["fake-openai-endpoint"]
+}'
+
+# Returns {..'key': 'my-new-key'}
+```
+
+**3. Test it!**
+
+<Tabs>
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": "this is a test request, write a short poem"
+        }
+    ],
+    extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
+        "metadata": {
+            "permissions": {
+                "enable_llm_guard_check": True # 👈 KEY CHANGE
+            },
+        }
+    }
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="curl" label="Curl Request">
+
+```bash
+# 👈 TEST KEY from step 2 goes in the Authorization header
+curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer my-new-key' \
+--data '{"model": "fake-openai-endpoint", "messages": [
+    {"role": "system", "content": "Be helpful"},
+    {"role": "user", "content": "What do you know?"}
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
+
 
 ### Content Moderation with LlamaGuard
 
diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py
index f368610cf..3a15ca52b 100644
--- a/enterprise/enterprise_hooks/llm_guard.py
+++ b/enterprise/enterprise_hooks/llm_guard.py
@@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             traceback.print_exc()
             raise e
 
-    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
+    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
         if self.llm_guard_mode == "key-specific":
             # check if llm guard enabled for specific keys only
             self.print_verbose(
@@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             return True
         elif self.llm_guard_mode == "all":
             return True
+        elif self.llm_guard_mode == "request-specific":
+            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
+            metadata = data.get("metadata", {})
+            permissions = metadata.get("permissions", {})
+            if (
+                "enable_llm_guard_check" in permissions
+                and permissions["enable_llm_guard_check"] == True
+            ):
+                return True
         return False
 
     async def async_moderation_hook(
@@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
         )
 
-        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
+        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
         if _proceed == False:
             return
 
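The new `request-specific` branch above reduces to a nested dictionary lookup on the request body. A minimal standalone sketch of the same check (the function name and sample payloads are illustrative, not part of the diff):

```python
from typing import Any, Dict


def request_enables_llm_guard(data: Dict[str, Any]) -> bool:
    # Walk data -> metadata -> permissions and look for the flag,
    # mirroring the 'request-specific' branch in should_proceed().
    # (The hook itself compares with == True; `is True` is the stricter spelling.)
    permissions = data.get("metadata", {}).get("permissions", {})
    return permissions.get("enable_llm_guard_check") is True


# The flag must be nested exactly as the docs example shows.
assert request_enables_llm_guard(
    {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
)
assert not request_enables_llm_guard({})  # no metadata -> moderation skipped
assert not request_enables_llm_guard(
    {"metadata": {"permissions": {"enable_llm_guard_check": False}}}
)
```
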
diff --git a/litellm/__init__.py b/litellm/__init__.py
index ef9374aa1..6539299a9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -64,7 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
-llm_guard_mode: Literal["all", "key-specific"] = "all"
+llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py
index 221db3213..6b06da9af 100644
--- a/litellm/tests/test_llm_guard.py
+++ b/litellm/tests/test_llm_guard.py
@@ -120,6 +120,46 @@ def test_llm_guard_key_specific_mode():
         api_key=_api_key, permissions={"enable_llm_guard_check": True}
     )
 
-    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    request_data = {}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+
+    assert should_proceed == True
+
+
+def test_llm_guard_request_specific_mode():
+    """
+    Tests to see if llm guard 'request-specific' permissions work
+    """
+    litellm.llm_guard_mode = "request-specific"
+
+    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
+
+    _api_key = "sk-12345"
+    # NOT ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+    )
+
+    request_data = {}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+
+    assert should_proceed == False
+
+    # ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, permissions={"enable_llm_guard_check": True}
+    )
+
+    request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
+
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
 
     assert should_proceed == True
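For completeness, the per-request toggle can also be exercised without the OpenAI SDK, since `extra_body` in the docs example simply merges extra keys into the JSON payload. A sketch using plain `requests` (the URL, model name, and key are the placeholder values from the docs above, and the `metadata` nesting matches the Python tab example):

```python
import requests

response = requests.post(
    "http://0.0.0.0:4000/v1/chat/completions",
    headers={
        "Authorization": "Bearer my-new-key",  # key generated in step 2
        "Content-Type": "application/json",
    },
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "What do you know?"}],
        # the same nesting should_proceed() inspects on the proxy side
        "metadata": {"permissions": {"enable_llm_guard_check": True}},
    },
)
print(response.json())
```
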