fix(llm_guard.py): enable request-specific llm guard flag

Author: Krrish Dholakia, 2024-04-08 21:15:21 -07:00
Parent: 763e92a03e
Commit: b6cd200676
4 changed files with 124 additions and 5 deletions
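
In short: with `llm_guard_mode` set to `"request-specific"`, LLM Guard moderation runs only for requests that opt in through their metadata. A minimal sketch of the opt-in payload, using only field names that appear in the diffs below:

```python
# Sketch: the request-body fields consulted by should_proceed() in
# "request-specific" mode (field names from the diffs below; the rest
# of the chat-completion request is omitted).
request_data = {
    "metadata": {
        "permissions": {
            "enable_llm_guard_check": True,  # 👈 opt this request into LLM Guard
        },
    },
}
```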

File 1 of 4: proxy docs (content moderation section)

@@ -74,7 +74,7 @@ curl --location 'http://localhost:4000/key/generate' \
 # Returns {..'key': 'my-new-key'}
 ```
-**2. Test it!**
+**3. Test it!**
 ```bash
 curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
@@ -87,6 +87,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
 }'
 ```
+#### Turn on/off per request
+
+**1. Update config**
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+  llm_guard_mode: "request-specific"
+```
+
+**2. Create new key**
+```bash
+curl --location 'http://localhost:4000/key/generate' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+        "models": ["fake-openai-endpoint"]
+    }'
+
+# Returns {..'key': 'my-new-key'}
+```
+
+**3. Test it!**
+
+<Tabs>
+<TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:4000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(
+    model="gpt-3.5-turbo",
+    messages=[
+        {
+            "role": "user",
+            "content": "this is a test request, write a short poem"
+        }
+    ],
+    extra_body={  # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
+        "metadata": {
+            "permissions": {
+                "enable_llm_guard_check": True  # 👈 KEY CHANGE
+            },
+        }
+    }
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="curl" label="Curl Request">
+
+```bash
+# uses the test key from step 2 (👈 TEST KEY) and opts in via metadata
+curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --header 'Authorization: Bearer my-new-key' \
+    --data '{
+        "model": "fake-openai-endpoint",
+        "messages": [
+            {"role": "system", "content": "Be helpful"},
+            {"role": "user", "content": "What do you know?"}
+        ],
+        "metadata": {"permissions": {"enable_llm_guard_check": true}}
+    }'
+```
+</TabItem>
+</Tabs>
+
 ### Content Moderation with LlamaGuard

File 2 of 4: llm_guard.py (enterprise hook)

@@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             traceback.print_exc()
             raise e
 
-    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
+    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
         if self.llm_guard_mode == "key-specific":
             # check if llm guard enabled for specific keys only
             self.print_verbose(
@@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
                 return True
         elif self.llm_guard_mode == "all":
             return True
+        elif self.llm_guard_mode == "request-specific":
+            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
+            metadata = data.get("metadata", {})
+            permissions = metadata.get("permissions", {})
+            if (
+                "enable_llm_guard_check" in permissions
+                and permissions["enable_llm_guard_check"] == True
+            ):
+                return True
         return False
 
     async def async_moderation_hook(
@@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
         )
-        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
+        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
         if _proceed == False:
             return

File 3 of 4: module-level settings

@@ -64,7 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
 blocked_user_list: Optional[Union[str, List]] = None
 banned_keywords_list: Optional[Union[str, List]] = None
-llm_guard_mode: Literal["all", "key-specific"] = "all"
+llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
 ##################
 logging: bool = True
 caching: bool = (
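
The tests below flip this module-level flag directly; on a running proxy the same setting comes from the `litellm_settings` block shown in the docs diff above. Roughly:

```python
import litellm

# programmatic equivalent of `llm_guard_mode: "request-specific"` in the
# proxy yaml config (this is exactly what the tests below do)
litellm.llm_guard_mode = "request-specific"
```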

File 4 of 4: LLM Guard tests

@@ -120,6 +120,46 @@ def test_llm_guard_key_specific_mode():
         api_key=_api_key, permissions={"enable_llm_guard_check": True}
     )
-    should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
+    request_data = {}
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
 
     assert should_proceed == True
 
+
+def test_llm_guard_request_specific_mode():
+    """
+    Tests to see if llm guard 'request-specific' permissions work
+    """
+    litellm.llm_guard_mode = "request-specific"
+    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
+
+    _api_key = "sk-12345"
+
+    # NOT ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+    )
+    request_data = {}
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+    assert should_proceed == False
+
+    # ENABLED
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, permissions={"enable_llm_guard_check": True}
+    )
+    request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
+    should_proceed = llm_guard.should_proceed(
+        user_api_key_dict=user_api_key_dict, data=request_data
+    )
+    assert should_proceed == True
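
One consequence worth noting: in the ENABLED case above the key also carries `permissions={"enable_llm_guard_check": True}`, but per the `should_proceed` diff only the request body's `metadata.permissions` is consulted in request-specific mode. A hypothetical extra case (not in the commit, reusing the same test-file imports) makes that explicit:

```python
# Hypothetical addition: request-specific mode ignores key permissions;
# only the request body's metadata.permissions matters.
def test_llm_guard_request_specific_ignores_key_permissions():
    litellm.llm_guard_mode = "request-specific"
    llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)

    # plain key, no key-level permissions
    user_api_key_dict = UserAPIKeyAuth(api_key="sk-12345")
    request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}

    should_proceed = llm_guard.should_proceed(
        user_api_key_dict=user_api_key_dict, data=request_data
    )
    assert should_proceed == True
```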