Merge pull request #2706 from BerriAI/litellm_key_llm_guardrails

feat(llm_guard.py): enable key-specific llm guard check
This commit is contained in:
Krish Dholakia 2024-03-26 19:02:11 -07:00 committed by GitHub
commit c1f8d346b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 141 additions and 34 deletions

View file

@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
:::
Features:
- ✅ Content Moderation with LLM Guard
- ✅ Content Moderation with LlamaGuard
- ✅ Content Moderation with Google Text Moderations
- ✅ Content Moderation with LLM Guard
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- ✅ Don't log/store specific requests (eg confidential LLM requests)
@ -23,6 +23,71 @@ Features:
## Content Moderation
### Content Moderation with LLM Guard
Set the LLM Guard API Base in your environment
```env
LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
```
Add `llmguard_moderations` as a callback
```yaml
litellm_settings:
callbacks: ["llmguard_moderations"]
```
Now you can easily test it
- Make a regular /chat/completion call
- Check your proxy logs for any statement with `LLM Guard:`
Expected results:
```
LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
```
#### Turn on/off per key
**1. Update config**
```yaml
litellm_settings:
callbacks: ["llmguard_moderations"]
llm_guard_mode: "key-specific"
```
**2. Create new key**
```bash
curl --location 'http://localhost:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"models": ["fake-openai-endpoint"],
"permissions": {
"enable_llm_guard_check": true # 👈 KEY CHANGE
}
}'
# Returns {..'key': 'my-new-key'}
```
**2. Test it!**
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
--data '{"model": "fake-openai-endpoint", "messages": [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"}
]
}'
```
### Content Moderation with LlamaGuard
Currently works with Sagemaker's LlamaGuard endpoint.
@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"]
llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt
```
### Content Moderation with LLM Guard
Set the LLM Guard API Base in your environment
```env
LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
```
Add `llmguard_moderations` as a callback
```yaml
litellm_settings:
callbacks: ["llmguard_moderations"]
```
Now you can easily test it
- Make a regular /chat/completion call
- Check your proxy logs for any statement with `LLM Guard:`
Expected results:
```
LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
```
### Content Moderation with Google Text Moderation

View file

@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""

View file

@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""

View file

@ -30,9 +30,12 @@ litellm.set_verbose = True
class _ENTERPRISE_LLMGuard(CustomLogger):
# Class variables or attributes
def __init__(
self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
self,
mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None,
):
self.mock_redacted_text = mock_redacted_text
self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only
return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
@ -92,9 +95,25 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
traceback.print_exc()
raise e
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
if self.llm_guard_mode == "key-specific":
# check if llm guard enabled for specific keys only
self.print_verbose(
f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
)
if (
user_api_key_dict.permissions.get("enable_llm_guard_check", False)
== True
):
return True
elif self.llm_guard_mode == "all":
return True
return False
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""
@ -103,7 +122,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
- Use the sanitized prompt returned
- LLM Guard can handle things like PII Masking, etc.
"""
self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
self.print_verbose(
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
)
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
if _proceed == False:
return
self.print_verbose("Makes LLM Guard Check")
try:
assert call_type in [
"completion",

View file

@ -1,6 +1,6 @@
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None
blocked_user_list: Optional[Union[str, List]] = None
banned_keywords_list: Optional[Union[str, List]] = None
llm_guard_mode: Literal["all", "key-specific"] = "all"
##################
logging: bool = True
caching: bool = (

View file

@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
pass

View file

@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
self.print_verbose(

View file

@ -6,7 +6,7 @@ from datetime import datetime
import importlib
from dotenv import load_dotenv
import urllib.parse as urlparse
from litellm._logging import verbose_proxy_logger
sys.path.append(os.getcwd())
@ -20,6 +20,8 @@ telemetry = None
def append_query_params(url, params):
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.debug(f"url: {url}")
verbose_proxy_logger.debug(f"params: {params}")
parsed_url = urlparse.urlparse(url)

View file

@ -3168,7 +3168,9 @@ async def chat_completion(
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(data=data, call_type="completion")
proxy_logging_obj.during_call_hook(
data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
)
)
start_time = time.time()

View file

@ -141,6 +141,7 @@ class ProxyLogging:
async def during_call_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"embeddings",
@ -157,7 +158,9 @@ class ProxyLogging:
try:
if isinstance(callback, CustomLogger):
await callback.async_moderation_hook(
data=new_data, call_type=call_type
data=new_data,
user_api_key_dict=user_api_key_dict,
call_type=call_type,
)
except Exception as e:
raise e

View file

@ -25,7 +25,6 @@ from litellm.caching import DualCache
### UNIT TESTS FOR LLM GUARD ###
# Test if PII masking works with input A
@pytest.mark.asyncio
async def test_llm_guard_valid_response():
"""
@ -54,13 +53,13 @@ async def test_llm_guard_valid_response():
}
]
},
user_api_key_dict=user_api_key_dict,
call_type="completion",
)
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# Test if PII masking works with input B (also test if the response != A's response)
@pytest.mark.asyncio
async def test_llm_guard_error_raising():
"""
@ -90,8 +89,37 @@ async def test_llm_guard_error_raising():
}
]
},
user_api_key_dict=user_api_key_dict,
call_type="completion",
)
pytest.fail(f"Should have failed - {str(e)}")
except Exception as e:
pass
def test_llm_guard_key_specific_mode():
"""
Tests to see if llm guard 'key-specific' permissions work
"""
litellm.llm_guard_mode = "key-specific"
llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
_api_key = "sk-12345"
# NOT ENABLED
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
)
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
assert should_proceed == False
# ENABLED
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key, permissions={"enable_llm_guard_check": True}
)
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
assert should_proceed == True