Merge pull request #2706 from BerriAI/litellm_key_llm_guardrails

feat(llm_guard.py): enable key-specific llm guard check
Krish Dholakia 2024-03-26 19:02:11 -07:00 committed by GitHub
commit c1f8d346b8
11 changed files with 141 additions and 34 deletions
View file
@@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
 :::
 Features:
+- ✅ Content Moderation with LLM Guard
 - ✅ Content Moderation with LlamaGuard
 - ✅ Content Moderation with Google Text Moderations
-- ✅ Content Moderation with LLM Guard
 - ✅ Reject calls from Blocked User list
 - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
 - ✅ Don't log/store specific requests (eg confidential LLM requests)
@@ -23,6 +23,71 @@ Features:
 ## Content Moderation
+### Content Moderation with LLM Guard
+
+Set the LLM Guard API Base in your environment
+
+```env
+LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
+```
+
+Add `llmguard_moderations` as a callback
+
+```yaml
+litellm_settings:
+    callbacks: ["llmguard_moderations"]
+```
+
+Now you can easily test it
+
+- Make a regular /chat/completion call
+- Check your proxy logs for any statement with `LLM Guard:`
+
+Expected results:
+
+```
+LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
+```
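For reference (a hypothetical sketch, not part of this commit's docs): the proxy-side check boils down to POSTing the prompt to the deployed LLM Guard API and inspecting the `is_valid` / `sanitized_prompt` fields shown in the log line above. The `analyze/prompt` route and payload shape below are assumptions about the deployed service:

```python
# Hypothetical sketch of a call to a deployed LLM Guard API.
# Assumes the service exposes an `analyze/prompt` route returning the
# {"sanitized_prompt": ..., "is_valid": ..., "scanners": ...} shape logged above.
import requests

LLM_GUARD_API_BASE = "http://0.0.0.0:8192/"


def check_prompt(prompt: str) -> dict:
    resp = requests.post(
        f"{LLM_GUARD_API_BASE}analyze/prompt",
        json={"prompt": prompt},  # assumed payload shape
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()


result = check_prompt("hello world")
print(result["is_valid"], result["sanitized_prompt"])
```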
+#### Turn on/off per key
+
+**1. Update config**
+
+```yaml
+litellm_settings:
+    callbacks: ["llmguard_moderations"]
+    llm_guard_mode: "key-specific"
+```
+**2. Create new key**
+
+```bash
+curl --location 'http://localhost:4000/key/generate' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "models": ["fake-openai-endpoint"],
+    "permissions": {
+        "enable_llm_guard_check": true
+    }
+}' # 👈 KEY CHANGE: "enable_llm_guard_check" in permissions
+
+# Returns {..'key': 'my-new-key'}
+```
+**3. Test it!**
+
+```bash
+curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer my-new-key' \
+--data '{"model": "fake-openai-endpoint", "messages": [
+    {"role": "system", "content": "Be helpful"},
+    {"role": "user", "content": "What do you know?"}
+]
+}' # 👈 TEST KEY goes in the Authorization header
+```
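With `llm_guard_mode: "key-specific"`, only requests authenticated with a key whose `permissions` carry `"enable_llm_guard_check": true` are scanned; calls made with any other key bypass the LLM Guard check entirely (see `should_proceed` in the llm_guard.py hunk below).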
 ### Content Moderation with LlamaGuard
 Currently works with Sagemaker's LlamaGuard endpoint.
@@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"]
 llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt
 ```
-### Content Moderation with LLM Guard
-
-Set the LLM Guard API Base in your environment
-
-```env
-LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
-```
-
-Add `llmguard_moderations` as a callback
-
-```yaml
-litellm_settings:
-    callbacks: ["llmguard_moderations"]
-```
-
-Now you can easily test it
-
-- Make a regular /chat/completion call
-- Check your proxy logs for any statement with `LLM Guard:`
-
-Expected results:
-
-```
-LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
-```
 ### Content Moderation with Google Text Moderation
View file
@@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
View file
@@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
View file
@@ -30,9 +30,12 @@ litellm.set_verbose = True
 class _ENTERPRISE_LLMGuard(CustomLogger):
     # Class variables or attributes
     def __init__(
-        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+        self,
+        mock_testing: bool = False,
+        mock_redacted_text: Optional[dict] = None,
     ):
         self.mock_redacted_text = mock_redacted_text
+        self.llm_guard_mode = litellm.llm_guard_mode
         if mock_testing == True: # for testing purposes only
             return
         self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
@@ -92,9 +95,25 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             traceback.print_exc()
             raise e
+    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
+        if self.llm_guard_mode == "key-specific":
+            # check if llm guard enabled for specific keys only
+            self.print_verbose(
+                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
+            )
+            if (
+                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
+                == True
+            ):
+                return True
+        elif self.llm_guard_mode == "all":
+            return True
+        return False
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
@@ -103,7 +122,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
         - Use the sanitized prompt returned
         - LLM Guard can handle things like PII Masking, etc.
         """
-        self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
+        self.print_verbose(
+            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
+        )
+        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
+        if _proceed == False:
+            return
+
+        self.print_verbose("Makes LLM Guard Check")
         try:
             assert call_type in [
                 "completion",
View file
@@ -1,6 +1,6 @@
 ### INIT VARIABLES ###
 import threading, requests, os
-from typing import Callable, List, Optional, Dict, Union, Any
+from typing import Callable, List, Optional, Dict, Union, Any, Literal
 from litellm.caching import Cache
 from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
 from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
@@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
+llm_guard_mode: Literal["all", "key-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
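Beyond the proxy YAML, the new module-level flag can also be set directly when importing litellm in Python — a minimal sketch:

```python
# Minimal sketch: toggling the new flag programmatically rather than via
# `litellm_settings` in the proxy config.
import litellm

litellm.llm_guard_mode = "key-specific"  # default is "all", per the diff above
```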
View file
@@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         pass
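Because the base-class hook now receives the calling key, user-defined callbacks can implement their own per-key gating. A minimal sketch of such a subclass — the import paths and the `enable_my_moderation` permission name are assumptions for illustration, not shown in this diff:

```python
# Hypothetical user-defined hook (not part of this PR) using the new
# user_api_key_dict argument to gate moderation per key.
from typing import Literal

from litellm.integrations.custom_logger import CustomLogger  # assumed path
from litellm.proxy._types import UserAPIKeyAuth  # assumed path


class KeyAwareModerator(CustomLogger):
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        # Mirror the llm_guard "key-specific" pattern: opt in via key permissions
        if not user_api_key_dict.permissions.get("enable_my_moderation", False):
            return  # key has not opted in - skip moderation
        # ... run custom moderation over data["messages"] here ...
```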
View file
@@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         self.print_verbose(
View file
@@ -6,7 +6,7 @@ from datetime import datetime
 import importlib
 from dotenv import load_dotenv
 import urllib.parse as urlparse
-from litellm._logging import verbose_proxy_logger
 
 sys.path.append(os.getcwd())
@@ -20,6 +20,8 @@ telemetry = None
 def append_query_params(url, params):
+    from litellm._logging import verbose_proxy_logger
+
     verbose_proxy_logger.debug(f"url: {url}")
     verbose_proxy_logger.debug(f"params: {params}")
     parsed_url = urlparse.urlparse(url)
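The two hunks above move the `verbose_proxy_logger` import from module scope into `append_query_params`, deferring it until the function actually runs — a common way to avoid import-order or circular-import problems at startup, though the diff itself doesn't state the motivation.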
View file
@@ -3168,7 +3168,9 @@ async def chat_completion(
     tasks = []
     tasks.append(
-        proxy_logging_obj.during_call_hook(data=data, call_type="completion")
+        proxy_logging_obj.during_call_hook(
+            data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+        )
     )
     start_time = time.time()
View file
@@ -141,6 +141,7 @@ class ProxyLogging:
     async def during_call_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal[
             "completion",
             "embeddings",
@@ -157,7 +158,9 @@ class ProxyLogging:
             try:
                 if isinstance(callback, CustomLogger):
                     await callback.async_moderation_hook(
-                        data=new_data, call_type=call_type
+                        data=new_data,
+                        user_api_key_dict=user_api_key_dict,
+                        call_type=call_type,
                     )
             except Exception as e:
                 raise e
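With this plumbing in place, every registered CustomLogger receives the calling key's metadata during the in-flight moderation pass — not just the LLM Guard hook — which is what makes key-aware hooks like the sketch earlier possible.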
View file
@@ -25,7 +25,6 @@ from litellm.caching import DualCache
 ### UNIT TESTS FOR LLM GUARD ###
 
-# Test if PII masking works with input A
 @pytest.mark.asyncio
 async def test_llm_guard_valid_response():
     """
@@ -54,13 +53,13 @@ async def test_llm_guard_valid_response():
                     }
                 ]
             },
+            user_api_key_dict=user_api_key_dict,
             call_type="completion",
         )
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
-# Test if PII masking works with input B (also test if the response != A's response)
 @pytest.mark.asyncio
 async def test_llm_guard_error_raising():
     """
@@ -90,8 +89,37 @@ async def test_llm_guard_error_raising():
                     }
                 ]
             },
+            user_api_key_dict=user_api_key_dict,
             call_type="completion",
         )
         pytest.fail(f"Should have failed - {str(e)}")
     except Exception as e:
         pass
+
+def test_llm_guard_key_specific_mode():
+    """
+    Tests to see if llm guard 'key-specific' permissions work
+    """
+    litellm.llm_guard_mode = "key-specific"
+
+    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
+
+    _api_key = "sk-12345"
+
+    # NOT ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+    )
+    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    assert should_proceed == False
+
+    # ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, permissions={"enable_llm_guard_check": True}
+    )
+    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    assert should_proceed == True
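The new test is synchronous — no `@pytest.mark.asyncio` marker is needed because `should_proceed` does no I/O — and can be run on its own with pytest's keyword filter, e.g. `pytest -k test_llm_guard_key_specific_mode`.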