forked from phoenix/litellm-mirror
fix(llm_guard.py): enable request-specific llm guard flag
This commit is contained in:
parent
763e92a03e
commit
b6cd200676
4 changed files with 124 additions and 5 deletions
|
@@ -74,7 +74,7 @@ curl --location 'http://localhost:4000/key/generate' \
|
||||||
# Returns {..'key': 'my-new-key'}
|
# Returns {..'key': 'my-new-key'}
|
||||||
```
|
```
|
||||||
|
|
||||||
**2. Test it!**
|
**3. Test it!**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
|
@@ -87,6 +87,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Turn on/off per request
|
||||||
|
|
||||||
|
**1. Update config**
|
||||||
|
```yaml
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["llmguard_moderations"]
|
||||||
|
llm_guard_mode: "request-specific"
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Create new key**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://localhost:4000/key/generate' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"models": ["fake-openai-endpoint"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Returns {..'key': 'my-new-key'}
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Test it!**
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
|
||||||
|
"metadata": {
|
||||||
|
"permissions": {
|
||||||
|
"enable_llm_guard_check": True # 👈 KEY CHANGE
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="curl" label="Curl Request">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
|
||||||
|
--data '{"model": "fake-openai-endpoint", "messages": [
|
||||||
|
{"role": "system", "content": "Be helpful"},
|
||||||
|
{"role": "user", "content": "What do you know?"}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Content Moderation with LlamaGuard
|
### Content Moderation with LlamaGuard
|
||||||
|
|
||||||
|
|
|
@@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
|
||||||
if self.llm_guard_mode == "key-specific":
|
if self.llm_guard_mode == "key-specific":
|
||||||
# check if llm guard enabled for specific keys only
|
# check if llm guard enabled for specific keys only
|
||||||
self.print_verbose(
|
self.print_verbose(
|
||||||
|
@@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
||||||
return True
|
return True
|
||||||
elif self.llm_guard_mode == "all":
|
elif self.llm_guard_mode == "all":
|
||||||
return True
|
return True
|
||||||
|
elif self.llm_guard_mode == "request-specific":
|
||||||
|
self.print_verbose(f"received metadata: {data.get('metadata', {})}")
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
permissions = metadata.get("permissions", {})
|
||||||
|
if (
|
||||||
|
"enable_llm_guard_check" in permissions
|
||||||
|
and permissions["enable_llm_guard_check"] == True
|
||||||
|
):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_moderation_hook(
|
async def async_moderation_hook(
|
||||||
|
@@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
||||||
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
|
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
|
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
|
||||||
if _proceed == False:
|
if _proceed == False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@@ -64,7 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
|
||||||
llamaguard_unsafe_content_categories: Optional[str] = None
|
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||||
blocked_user_list: Optional[Union[str, List]] = None
|
blocked_user_list: Optional[Union[str, List]] = None
|
||||||
banned_keywords_list: Optional[Union[str, List]] = None
|
banned_keywords_list: Optional[Union[str, List]] = None
|
||||||
llm_guard_mode: Literal["all", "key-specific"] = "all"
|
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||||
##################
|
##################
|
||||||
logging: bool = True
|
logging: bool = True
|
||||||
caching: bool = (
|
caching: bool = (
|
||||||
|
|
|
@@ -120,6 +120,46 @@ def test_llm_guard_key_specific_mode():
|
||||||
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
|
request_data = {}
|
||||||
|
|
||||||
|
should_proceed = llm_guard.should_proceed(
|
||||||
|
user_api_key_dict=user_api_key_dict, data=request_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert should_proceed == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_guard_request_specific_mode():
|
||||||
|
"""
|
||||||
|
Tests to see if llm guard 'request-specific' permissions work
|
||||||
|
"""
|
||||||
|
litellm.llm_guard_mode = "request-specific"
|
||||||
|
|
||||||
|
llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
# NOT ENABLED
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {}
|
||||||
|
|
||||||
|
should_proceed = llm_guard.should_proceed(
|
||||||
|
user_api_key_dict=user_api_key_dict, data=request_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert should_proceed == False
|
||||||
|
|
||||||
|
# ENABLED
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
|
||||||
|
|
||||||
|
should_proceed = llm_guard.should_proceed(
|
||||||
|
user_api_key_dict=user_api_key_dict, data=request_data
|
||||||
|
)
|
||||||
|
|
||||||
assert should_proceed == True
|
assert should_proceed == True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue