Merge pull request #2706 from BerriAI/litellm_key_llm_guardrails
feat(llm_guard.py): enable key-specific llm guard check
commit c1f8d346b8

11 changed files with 141 additions and 34 deletions
@@ -12,9 +12,9 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
 :::

 Features:

+- ✅ Content Moderation with LLM Guard
 - ✅ Content Moderation with LlamaGuard
 - ✅ Content Moderation with Google Text Moderations
-- ✅ Content Moderation with LLM Guard
 - ✅ Reject calls from Blocked User list
 - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
 - ✅ Don't log/store specific requests (eg confidential LLM requests)
@@ -23,6 +23,71 @@ Features:


 ## Content Moderation
+### Content Moderation with LLM Guard
+
+Set the LLM Guard API Base in your environment
+
+```env
+LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
+```
+
+Add `llmguard_moderations` as a callback
+
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+```
+
+Now you can easily test it
+
+- Make a regular /chat/completion call
+
+- Check your proxy logs for any statement with `LLM Guard:`
+
+Expected results:
+
+```
+LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
+```
+
+#### Turn on/off per key
+
+**1. Update config**
+
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+  llm_guard_mode: "key-specific"
+```
+
+**2. Create new key**
+
+```bash
+curl --location 'http://localhost:4000/key/generate' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "models": ["fake-openai-endpoint"],
+        "permissions": {
+            "enable_llm_guard_check": true # 👈 KEY CHANGE
+        }
+    }'
+
+# Returns {..'key': 'my-new-key'}
+```
+
+**3. Test it!**
+
+```bash
+curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
+    --data '{"model": "fake-openai-endpoint", "messages": [
+        {"role": "system", "content": "Be helpful"},
+        {"role": "user", "content": "What do you know?"}
+    ]
+}'
+```
+
 ### Content Moderation with LlamaGuard

 Currently works with Sagemaker's LlamaGuard endpoint.
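For readers driving the proxy from Python rather than curl, here is a rough equivalent of the key-generation call documented above. This is a sketch only: the proxy address, master key, model name and permission flag are copied from the example and are not new API surface.

```python
import requests

# Hypothetical Python version of the /key/generate curl above.
resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "models": ["fake-openai-endpoint"],
        # per-key opt-in read by the key-specific LLM Guard check
        "permissions": {"enable_llm_guard_check": True},
    },
)
new_key = resp.json()["key"]  # use as the Bearer token in /chat/completions
```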
@@ -55,32 +120,7 @@ callbacks: ["llamaguard_moderations"]
 llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt
 ```

-### Content Moderation with LLM Guard
-
-Set the LLM Guard API Base in your environment
-
-```env
-LLM_GUARD_API_BASE = "http://0.0.0.0:8192" # deployed llm guard api
-```
-
-Add `llmguard_moderations` as a callback
-
-```yaml
-litellm_settings:
-  callbacks: ["llmguard_moderations"]
-```
-
-Now you can easily test it
-
-- Make a regular /chat/completion call
-
-- Check your proxy logs for any statement with `LLM Guard:`
-
-Expected results:
-
-```
-LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
-```
-
 ### Content Moderation with Google Text Moderation
@@ -96,6 +96,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
@@ -99,6 +99,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
@@ -30,9 +30,12 @@ litellm.set_verbose = True
 class _ENTERPRISE_LLMGuard(CustomLogger):
     # Class variables or attributes
     def __init__(
-        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+        self,
+        mock_testing: bool = False,
+        mock_redacted_text: Optional[dict] = None,
     ):
         self.mock_redacted_text = mock_redacted_text
+        self.llm_guard_mode = litellm.llm_guard_mode
         if mock_testing == True:  # for testing purposes only
             return
         self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
@@ -92,9 +95,25 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             traceback.print_exc()
             raise e

+    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
+        if self.llm_guard_mode == "key-specific":
+            # check if llm guard enabled for specific keys only
+            self.print_verbose(
+                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
+            )
+            if (
+                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
+                == True
+            ):
+                return True
+        elif self.llm_guard_mode == "all":
+            return True
+        return False
+
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         """
@@ -103,7 +122,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
         - Use the sanitized prompt returned
         - LLM Guard can handle things like PII Masking, etc.
         """
-        self.print_verbose(f"Inside LLM Guard Pre-Call Hook")
+        self.print_verbose(
+            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
+        )
+
+        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
+        if _proceed == False:
+            return
+
+        self.print_verbose("Makes LLM Guard Check")
         try:
             assert call_type in [
                 "completion",
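Taken together, the two hunks above mean the guard now short-circuits per key: in "key-specific" mode the moderation call only runs for keys whose permissions opt in. A minimal sketch of the resulting behaviour, mirroring the unit test added at the bottom of this PR; the import path is illustrative and assumes the enterprise hooks are importable from your working directory.

```python
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard  # assumed path

litellm.llm_guard_mode = "key-specific"  # must be set before the hook is constructed
llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)  # skips the LLM_GUARD_API_BASE lookup

# Key without the permission: should_proceed is False, async_moderation_hook returns early.
default_key = UserAPIKeyAuth(api_key="sk-12345")
assert llm_guard.should_proceed(user_api_key_dict=default_key) is False

# Key generated with {"permissions": {"enable_llm_guard_check": true}}: the guard check runs.
opted_in_key = UserAPIKeyAuth(
    api_key="sk-12345", permissions={"enable_llm_guard_check": True}
)
assert llm_guard.should_proceed(user_api_key_dict=opted_in_key) is True
```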
@@ -1,6 +1,6 @@
 ### INIT VARIABLES ###
 import threading, requests, os
-from typing import Callable, List, Optional, Dict, Union, Any
+from typing import Callable, List, Optional, Dict, Union, Any, Literal
 from litellm.caching import Cache
 from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
 from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
@@ -64,6 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
+llm_guard_mode: Literal["all", "key-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
@@ -75,6 +75,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         pass
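Because the base `CustomLogger` signature changes here, user-defined callbacks that override `async_moderation_hook` need to accept the new parameter as well. A hedged sketch of an updated subclass; the class name and gating logic are illustrative and not part of this PR.

```python
from typing import Literal

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class MyModerationHook(CustomLogger):  # hypothetical example subclass
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,  # new argument introduced by this PR
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        # e.g. only moderate calls from keys that explicitly opted in
        if not user_api_key_dict.permissions.get("enable_llm_guard_check", False):
            return
        ...  # custom moderation logic goes here
```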
@@ -199,6 +199,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
     async def async_moderation_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal["completion", "embeddings", "image_generation"],
     ):
         self.print_verbose(
@@ -6,7 +6,7 @@ from datetime import datetime
 import importlib
 from dotenv import load_dotenv
 import urllib.parse as urlparse
-from litellm._logging import verbose_proxy_logger

 sys.path.append(os.getcwd())

@@ -20,6 +20,8 @@ telemetry = None


 def append_query_params(url, params):
+    from litellm._logging import verbose_proxy_logger
+
     verbose_proxy_logger.debug(f"url: {url}")
     verbose_proxy_logger.debug(f"params: {params}")
     parsed_url = urlparse.urlparse(url)
@@ -3168,7 +3168,9 @@ async def chat_completion(

     tasks = []
     tasks.append(
-        proxy_logging_obj.during_call_hook(data=data, call_type="completion")
+        proxy_logging_obj.during_call_hook(
+            data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+        )
     )

     start_time = time.time()
@@ -141,6 +141,7 @@ class ProxyLogging:
     async def during_call_hook(
         self,
         data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
         call_type: Literal[
             "completion",
             "embeddings",
@@ -157,7 +158,9 @@ class ProxyLogging:
             try:
                 if isinstance(callback, CustomLogger):
                     await callback.async_moderation_hook(
-                        data=new_data, call_type=call_type
+                        data=new_data,
+                        user_api_key_dict=user_api_key_dict,
+                        call_type=call_type,
                     )
             except Exception as e:
                 raise e
@@ -25,7 +25,6 @@ from litellm.caching import DualCache
 ### UNIT TESTS FOR LLM GUARD ###


-# Test if PII masking works with input A
 @pytest.mark.asyncio
 async def test_llm_guard_valid_response():
     """
@@ -54,13 +53,13 @@ async def test_llm_guard_valid_response():
                     }
                 ]
             },
+            user_api_key_dict=user_api_key_dict,
             call_type="completion",
         )
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


-# Test if PII masking works with input B (also test if the response != A's response)
 @pytest.mark.asyncio
 async def test_llm_guard_error_raising():
     """
@@ -90,8 +89,37 @@ async def test_llm_guard_error_raising():
                     }
                 ]
             },
+            user_api_key_dict=user_api_key_dict,
             call_type="completion",
         )
         pytest.fail(f"Should have failed - {str(e)}")
     except Exception as e:
         pass
+
+
+def test_llm_guard_key_specific_mode():
+    """
+    Tests to see if llm guard 'key-specific' permissions work
+    """
+    litellm.llm_guard_mode = "key-specific"
+
+    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
+
+    _api_key = "sk-12345"
+    # NOT ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+    )
+
+    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+
+    assert should_proceed == False
+
+    # ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, permissions={"enable_llm_guard_check": True}
+    )
+
+    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+
+    assert should_proceed == True
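To exercise the new key-specific test locally, one option is pytest's Python entry point. This is a sketch; it assumes pytest and the enterprise dependencies are installed, and selects the test by name rather than by file path.

```python
import sys

import pytest

# -k selects test_llm_guard_key_specific_mode wherever it lives in the test tree.
sys.exit(pytest.main(["-k", "test_llm_guard_key_specific_mode", "-q"]))
```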