forked from phoenix/litellm-mirror
Merge pull request #2706 from BerriAI/litellm_key_llm_guardrails
feat(llm_guard.py): enable key-specific llm guard check
This commit is contained in:
commit
c1f8d346b8
11 changed files with 141 additions and 34 deletions
|
@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
|
|||
:::
|
||||
|
||||
Features:
|
||||
- ✅ Content Moderation with LLM Guard
|
||||
- ✅ Content Moderation with LlamaGuard
|
||||
- ✅ Content Moderation with Google Text Moderations
|
||||
- ✅ Content Moderation with LLM Guard
|
||||
- ✅ Reject calls from Blocked User list
|
||||
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
|
||||
- ✅ Don't log/store specific requests (eg confidential LLM requests)
|
||||
|
@ -23,6 +23,71 @@ Features:
|
|||
|
||||
|
||||
## Content Moderation
|
||||
### Content Moderation with LLM Guard
|
||||
|
||||
Set the LLM Guard API Base in your environment
|
||||
|
||||
```env
|
||||
LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
|
||||
```
|
||||
|
||||
Add `llmguard_moderations` as a callback
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
callbacks: ["llmguard_moderations"]
|
||||
```
|
||||
|
||||
Now you can easily test it
|
||||
|
||||
- Make a regular /chat/completion call
|
||||
|
||||
- Check your proxy logs for any statement with `LLM Guard:`
|
||||
|
||||
Expected results:
|
||||
|
||||
```
|
||||
LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
|
||||
```
|
||||
#### Turn on/off per key
|
||||
|
||||
**1. Update config**
|
||||
```yaml
|
||||
litellm_settings:
|
||||
callbacks: ["llmguard_moderations"]
|
||||
llm_guard_mode: "key-specific"
|
||||
```
|
||||
|
||||
**2. Create new key**
|
||||
|
||||
```bash
|
||||
curl --location 'http://localhost:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"models": ["fake-openai-endpoint"],
|
||||
"permissions": {
|
||||
"enable_llm_guard_check": true # 👈 KEY CHANGE
|
||||
}
|
||||
}'
|
||||
|
||||
# Returns {..'key': 'my-new-key'}
|
||||
```
|
||||
|
||||
**2. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
|
||||
--data '{"model": "fake-openai-endpoint", "messages": [
|
||||
{"role": "system", "content": "Be helpful"},
|
||||
{"role": "user", "content": "What do you know?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
### Content Moderation with LlamaGuard
|
||||
|
||||
Currently works with Sagemaker's LlamaGuard endpoint.
|
||||
|
@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"]
|
|||
llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt
|
||||
```
|
||||
|
||||
### Content Moderation with LLM Guard
|
||||
|
||||
Set the LLM Guard API Base in your environment
|
||||
|
||||
```env
|
||||
LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
|
||||
```
|
||||
|
||||
Add `llmguard_moderations` as a callback
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
callbacks: ["llmguard_moderations"]
|
||||
```
|
||||
|
||||
Now you can easily test it
|
||||
|
||||
- Make a regular /chat/completion call
|
||||
|
||||
- Check your proxy logs for any statement with `LLM Guard:`
|
||||
|
||||
Expected results:
|
||||
|
||||
```
|
||||
LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
|
||||
```
|
||||
|
||||
### Content Moderation with Google Text Moderation
|
||||
|
||||
|
|
|
@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
|
|||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
|
|||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -30,9 +30,12 @@ litellm.set_verbose = True
|
|||
class _ENTERPRISE_LLMGuard(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
|
||||
self,
|
||||
mock_testing: bool = False,
|
||||
mock_redacted_text: Optional[dict] = None,
|
||||
):
|
||||
self.mock_redacted_text = mock_redacted_text
|
||||
self.llm_guard_mode = litellm.llm_guard_mode
|
||||
if mock_testing == True: # for testing purposes only
|
||||
return
|
||||
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
|
||||
|
@ -92,9 +95,25 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||
if self.llm_guard_mode == "key-specific":
|
||||
# check if llm guard enabled for specific keys only
|
||||
self.print_verbose(
|
||||
f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
|
||||
)
|
||||
if (
|
||||
user_api_key_dict.permissions.get("enable_llm_guard_check", False)
|
||||
== True
|
||||
):
|
||||
return True
|
||||
elif self.llm_guard_mode == "all":
|
||||
return True
|
||||
return False
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
"""
|
||||
|
@ -103,7 +122,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
- Use the sanitized prompt returned
|
||||
- LLM Guard can handle things like PII Masking, etc.
|
||||
"""
|
||||
self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
|
||||
self.print_verbose(
|
||||
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
|
||||
)
|
||||
|
||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
if _proceed == False:
|
||||
return
|
||||
|
||||
self.print_verbose("Makes LLM Guard Check")
|
||||
try:
|
||||
assert call_type in [
|
||||
"completion",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
### INIT VARIABLES ###
|
||||
import threading, requests, os
|
||||
from typing import Callable, List, Optional, Dict, Union, Any
|
||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
|
||||
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
|
||||
|
@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
|
|||
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||
blocked_user_list: Optional[Union[str, List]] = None
|
||||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
llm_guard_mode: Literal["all", "key-specific"] = "all"
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
|
|
|
@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
pass
|
||||
|
|
|
@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
self.print_verbose(
|
||||
|
|
|
@ -6,7 +6,7 @@ from datetime import datetime
|
|||
import importlib
|
||||
from dotenv import load_dotenv
|
||||
import urllib.parse as urlparse
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
|
@ -20,6 +20,8 @@ telemetry = None
|
|||
|
||||
|
||||
def append_query_params(url, params):
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.debug(f"url: {url}")
|
||||
verbose_proxy_logger.debug(f"params: {params}")
|
||||
parsed_url = urlparse.urlparse(url)
|
||||
|
|
|
@ -3168,7 +3168,9 @@ async def chat_completion(
|
|||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
proxy_logging_obj.during_call_hook(data=data, call_type="completion")
|
||||
proxy_logging_obj.during_call_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
|
||||
)
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
|
|
@ -141,6 +141,7 @@ class ProxyLogging:
|
|||
async def during_call_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
|
@ -157,7 +158,9 @@ class ProxyLogging:
|
|||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_moderation_hook(
|
||||
data=new_data, call_type=call_type
|
||||
data=new_data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type=call_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -25,7 +25,6 @@ from litellm.caching import DualCache
|
|||
### UNIT TESTS FOR LLM GUARD ###
|
||||
|
||||
|
||||
# Test if PII masking works with input A
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_guard_valid_response():
|
||||
"""
|
||||
|
@ -54,13 +53,13 @@ async def test_llm_guard_valid_response():
|
|||
}
|
||||
]
|
||||
},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type="completion",
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
# Test if PII masking works with input B (also test if the response != A's response)
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_guard_error_raising():
|
||||
"""
|
||||
|
@ -90,8 +89,37 @@ async def test_llm_guard_error_raising():
|
|||
}
|
||||
]
|
||||
},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type="completion",
|
||||
)
|
||||
pytest.fail(f"Should have failed - {str(e)}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def test_llm_guard_key_specific_mode():
|
||||
"""
|
||||
Tests to see if llm guard 'key-specific' permissions work
|
||||
"""
|
||||
litellm.llm_guard_mode = "key-specific"
|
||||
|
||||
llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
|
||||
|
||||
_api_key = "sk-12345"
|
||||
# NOT ENABLED
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key=_api_key,
|
||||
)
|
||||
|
||||
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
|
||||
assert should_proceed == False
|
||||
|
||||
# ENABLED
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
||||
)
|
||||
|
||||
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
|
||||
assert should_proceed == True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue