forked from phoenix/litellm-mirror
fix(llm_guard.py): enable request-specific llm guard flag
This commit is contained in:
parent
763e92a03e
commit
b6cd200676
4 changed files with 124 additions and 5 deletions
|
@ -74,7 +74,7 @@ curl --location 'http://localhost:4000/key/generate' \
|
|||
# Returns {..'key': 'my-new-key'}
|
||||
```
|
||||
|
||||
**2. Test it!**
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
|
@ -87,6 +87,76 @@ curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
|||
}'
|
||||
```
|
||||
|
||||
#### Turn on/off per request
|
||||
|
||||
**1. Update config**
|
||||
```yaml
|
||||
litellm_settings:
|
||||
callbacks: ["llmguard_moderations"]
|
||||
llm_guard_mode: "request-specific"
|
||||
```
|
||||
|
||||
**2. Create new key**
|
||||
|
||||
```bash
|
||||
curl --location 'http://localhost:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"models": ["fake-openai-endpoint"]
|
||||
}'
|
||||
|
||||
# Returns {..'key': 'my-new-key'}
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
],
|
||||
extra_body={ # pass in any provider-specific param, if not supported by openai, https://docs.litellm.ai/docs/completion/input#provider-specific-params
|
||||
"metadata": {
|
||||
"permissions": {
|
||||
"enable_llm_guard_check": True # 👈 KEY CHANGE
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer my-new-key' \ # 👈 TEST KEY
|
||||
--data '{"model": "fake-openai-endpoint", "messages": [
|
||||
{"role": "system", "content": "Be helpful"},
|
||||
{"role": "user", "content": "What do you know?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Content Moderation with LlamaGuard
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||
def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
|
||||
if self.llm_guard_mode == "key-specific":
|
||||
# check if llm guard enabled for specific keys only
|
||||
self.print_verbose(
|
||||
|
@ -108,6 +108,15 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
return True
|
||||
elif self.llm_guard_mode == "all":
|
||||
return True
|
||||
elif self.llm_guard_mode == "request-specific":
|
||||
self.print_verbose(f"received metadata: {data.get('metadata', {})}")
|
||||
metadata = data.get("metadata", {})
|
||||
permissions = metadata.get("permissions", {})
|
||||
if (
|
||||
"enable_llm_guard_check" in permissions
|
||||
and permissions["enable_llm_guard_check"] == True
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def async_moderation_hook(
|
||||
|
@ -126,7 +135,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
|
|||
f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
|
||||
)
|
||||
|
||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
_proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
|
||||
if _proceed == False:
|
||||
return
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ google_moderation_confidence_threshold: Optional[float] = None
|
|||
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||
blocked_user_list: Optional[Union[str, List]] = None
|
||||
banned_keywords_list: Optional[Union[str, List]] = None
|
||||
llm_guard_mode: Literal["all", "key-specific"] = "all"
|
||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||
##################
|
||||
logging: bool = True
|
||||
caching: bool = (
|
||||
|
|
|
@ -120,6 +120,46 @@ def test_llm_guard_key_specific_mode():
|
|||
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
||||
)
|
||||
|
||||
should_proceed = llm_guard.should_proceed(user_api_key_dict=user_api_key_dict)
|
||||
request_data = {}
|
||||
|
||||
should_proceed = llm_guard.should_proceed(
|
||||
user_api_key_dict=user_api_key_dict, data=request_data
|
||||
)
|
||||
|
||||
assert should_proceed == True
|
||||
|
||||
|
||||
def test_llm_guard_request_specific_mode():
|
||||
"""
|
||||
Tests to see if llm guard 'request-specific' permissions work
|
||||
"""
|
||||
litellm.llm_guard_mode = "request-specific"
|
||||
|
||||
llm_guard = _ENTERPRISE_LLMGuard(mock_testing=True)
|
||||
|
||||
_api_key = "sk-12345"
|
||||
# NOT ENABLED
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key=_api_key,
|
||||
)
|
||||
|
||||
request_data = {}
|
||||
|
||||
should_proceed = llm_guard.should_proceed(
|
||||
user_api_key_dict=user_api_key_dict, data=request_data
|
||||
)
|
||||
|
||||
assert should_proceed == False
|
||||
|
||||
# ENABLED
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key=_api_key, permissions={"enable_llm_guard_check": True}
|
||||
)
|
||||
|
||||
request_data = {"metadata": {"permissions": {"enable_llm_guard_check": True}}}
|
||||
|
||||
should_proceed = llm_guard.should_proceed(
|
||||
user_api_key_dict=user_api_key_dict, data=request_data
|
||||
)
|
||||
|
||||
assert should_proceed == True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue