(feat) Support Dynamic Params for guardrails (#7415)

* update CustomGuardrail

* unit test custom guardrails

* add dynamic params for aporia

* add dynamic params to bedrock guard

* add dynamic params for all guardrails

* fix linting

* fix should_run_guardrail

* _validate_premium_user

* update guardrail doc

* doc update

* update code q

* should_run_guardrail
This commit is contained in:
Ishaan Jaff 2024-12-25 16:07:29 -08:00 committed by GitHub
parent 77fa751639
commit 0ce5f9fe58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 411 additions and 21 deletions

View file

@ -114,6 +114,88 @@ curl -i http://localhost:4000/v1/chat/completions \
## Advanced
### ✨ Pass additional parameters to guardrail
:::info
✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
Use this to pass additional parameters to the guardrail API call. e.g. things like success threshold. **[See `guardrails` spec for more details](#spec-guardrails-parameter)**
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
Set `guardrails={"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}` to pass additional parameters to the guardrail
In this example `success_threshold=0.9` is passed to the `aporia-pre-guard` guardrail request body
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
extra_body={
"guardrails": [
"aporia-pre-guard": {
"extra_body": {
"success_threshold": 0.9
}
}
]
}
)
print(response)
```
</TabItem>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"guardrails": [
"aporia-pre-guard": {
"extra_body": {
"success_threshold": 0.9
}
}
]
}'
```
</TabItem>
</Tabs>
### ✨ Control Guardrails per Project (API Key)
:::info
@ -252,4 +334,43 @@ Expected response
{
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}
```
## Spec: `guardrails` Parameter
The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`).
### Format Options
1. Simple List Format:
```python
"guardrails": [
"aporia-pre-guard",
"aporia-post-guard"
]
```
2. Advanced Dictionary Format:
In this format the dictionary key is `guardrail_name` you want to run
```python
"guardrails": {
"aporia-pre-guard": {
"extra_body": {
"success_threshold": 0.9,
"other_param": "value"
}
}
}
```
### Type Definition
```python
guardrails: Union[
List[str], # Simple list of guardrail names
Dict[str, DynamicGuardrailParams] # Advanced configuration
]
class DynamicGuardrailParams:
extra_body: Dict[str, Any] # Additional parameters for the guardrail
```

View file

@ -1,8 +1,8 @@
from typing import List, Optional
from typing import Dict, List, Optional, Union
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
class CustomGuardrail(CustomLogger):
@ -26,9 +26,31 @@ class CustomGuardrail(CustomLogger):
)
super().__init__(**kwargs)
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
def get_guardrail_from_metadata(
self, data: dict
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
"""
Returns the guardrail(s) to be run from the metadata
"""
metadata = data.get("metadata") or {}
requested_guardrails = metadata.get("guardrails") or []
return requested_guardrails
def _guardrail_is_in_requested_guardrails(
self,
requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
) -> bool:
for _guardrail in requested_guardrails:
if isinstance(_guardrail, dict):
if self.guardrail_name in _guardrail:
return True
elif isinstance(_guardrail, str):
if self.guardrail_name == _guardrail:
return True
return False
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
requested_guardrails = self.get_guardrail_from_metadata(data)
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
@ -40,7 +62,7 @@ class CustomGuardrail(CustomLogger):
if (
self.event_hook
and self.guardrail_name not in requested_guardrails
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
and event_type.value != "logging_only"
):
return False
@ -49,3 +71,51 @@ class CustomGuardrail(CustomLogger):
return False
return True
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
"""
Returns `extra_body` to be added to the request body for the Guardrail API call
Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.
```
[{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
```
Will return: for guardrail=`lakera-guard`:
{
"foo": "bar"
}
Args:
request_data: The original `request_data` passed to LiteLLM Proxy
"""
requested_guardrails = self.get_guardrail_from_metadata(request_data)
# Look for the guardrail configuration matching self.guardrail_name
for guardrail in requested_guardrails:
if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
# Get the configuration for this guardrail
guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
**guardrail[self.guardrail_name]
)
if self._validate_premium_user() is not True:
return {}
# Return the extra_body if it exists, otherwise empty dict
return guardrail_config.get("extra_body", {})
return {}
def _validate_premium_user(self) -> bool:
"""
Returns True if the user is a premium user
"""
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
verbose_logger.warning(
f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
)
return False
return True

View file

@ -86,12 +86,19 @@ class AporiaGuardrail(CustomGuardrail):
return data
async def make_aporia_api_request(
self, new_messages: List[dict], response_string: Optional[str] = None
self,
request_data: dict,
new_messages: List[dict],
response_string: Optional[str] = None,
):
data = await self.prepare_aporia_request(
new_messages=new_messages, response_string=response_string
)
data.update(
self.get_guardrail_dynamic_request_body_params(request_data=request_data)
)
_json_data = json.dumps(data)
"""
@ -155,7 +162,9 @@ class AporiaGuardrail(CustomGuardrail):
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
if response_str is not None:
await self.make_aporia_api_request(
response_string=response_str, new_messages=data.get("messages", [])
request_data=data,
response_string=response_str,
new_messages=data.get("messages", []),
)
add_guardrail_to_applied_guardrails_header(
@ -199,7 +208,10 @@ class AporiaGuardrail(CustomGuardrail):
new_messages = self.transform_messages(messages=data["messages"])
if new_messages is not None:
await self.make_aporia_api_request(new_messages=new_messages)
await self.make_aporia_api_request(
request_data=data,
new_messages=new_messages,
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)

View file

@ -149,7 +149,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def _prepare_request(
self,
credentials,
data: BedrockRequest,
data: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@ -186,18 +186,23 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
):
credentials, aws_region_name = self._load_credentials()
request_data: BedrockRequest = self.convert_to_bedrock_format(
messages=kwargs.get("messages"), response=response
bedrock_request_data: dict = dict(
self.convert_to_bedrock_format(
messages=kwargs.get("messages"), response=response
)
)
bedrock_request_data.update(
self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
)
prepared_request = self._prepare_request(
credentials=credentials,
data=request_data,
data=bedrock_request_data,
optional_params=self.optional_params,
aws_region_name=aws_region_name,
)
verbose_proxy_logger.debug(
"Bedrock AI request body: %s, url %s, headers: %s",
request_data,
bedrock_request_data,
prepared_request.url,
prepared_request.headers,
)

View file

@ -48,10 +48,13 @@ class GuardrailsAI(CustomGuardrail):
supported_event_hooks = [GuardrailEventHooks.post_call]
super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
async def make_guardrails_ai_api_request(self, llm_output: str):
async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
from httpx import URL
data = {"llmOutput": llm_output}
data = {
"llmOutput": llm_output,
**self.get_guardrail_dynamic_request_body_params(request_data=request_data),
}
_json_data = json.dumps(data)
response = await litellm.module_level_aclient.post(
url=str(
@ -96,7 +99,9 @@ class GuardrailsAI(CustomGuardrail):
response_str: str = get_content_from_model_response(response)
if response_str is not None and len(response_str) > 0:
await self.make_guardrails_ai_api_request(llm_output=response_str)
await self.make_guardrails_ai_api_request(
llm_output=response_str, request_data=data
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name

View file

@ -216,14 +216,27 @@ class lakeraAI_Moderation(CustomGuardrail):
"Skipping lakera prompt injection, no roles with messages found"
)
return
data = {"input": lakera_input}
_json_data = json.dumps(data)
_data = {"input": lakera_input}
_json_data = json.dumps(
_data,
**self.get_guardrail_dynamic_request_body_params(request_data=data),
)
elif "input" in data and isinstance(data["input"], str):
text = data["input"]
_json_data = json.dumps({"input": text})
_json_data = json.dumps(
{
"input": text,
**self.get_guardrail_dynamic_request_body_params(request_data=data),
}
)
elif "input" in data and isinstance(data["input"], list):
text = "\n".join(data["input"])
_json_data = json.dumps({"input": text})
_json_data = json.dumps(
{
"input": text,
**self.get_guardrail_dynamic_request_body_params(request_data=data),
}
)
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)

View file

@ -132,6 +132,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
text: str,
output_parse_pii: bool,
presidio_config: Optional[PresidioPerRequestConfig],
request_data: dict,
) -> str:
"""
[TODO] make this more performant for high-throughput scenario
@ -150,7 +151,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if self.ad_hoc_recognizers is not None:
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
# End of constructing Request 1
analyze_payload.update(
self.get_guardrail_dynamic_request_body_params(
request_data=request_data
)
)
redacted_text = None
verbose_proxy_logger.debug(
"Making request to: %s with payload: %s",
@ -235,6 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
text=m["content"],
output_parse_pii=self.output_parse_pii,
presidio_config=presidio_config,
request_data=data,
)
)
responses = await asyncio.gather(*tasks)
@ -311,6 +317,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
text=text_str,
output_parse_pii=False,
presidio_config=presidio_config,
request_data=kwargs,
)
) # need to pass separately b/c presidio has context window limits
responses = await asyncio.gather(*tasks)

View file

@ -12,6 +12,14 @@ model_list:
model: bedrock/*
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"
# for /files endpoints
# For /fine_tuning/jobs endpoints
finetune_settings:

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Literal, Optional, TypedDict
from typing import Any, Dict, List, Literal, Optional, TypedDict
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@ -132,3 +132,7 @@ class BedrockContentItem(TypedDict, total=False):
class BedrockRequest(TypedDict, total=False):
source: Literal["INPUT", "OUTPUT"]
content: List[BedrockContentItem]
class DynamicGuardrailParams(TypedDict):
extra_body: Dict[str, Any]

View file

@ -0,0 +1,145 @@
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from typing import Any, Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
def test_get_guardrail_from_metadata():
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
# Test with empty metadata
assert guardrail.get_guardrail_from_metadata({}) == []
# Test with guardrails in metadata
data = {"metadata": {"guardrails": ["guardrail1", "guardrail2"]}}
assert guardrail.get_guardrail_from_metadata(data) == ["guardrail1", "guardrail2"]
# Test with dict guardrails
data = {
"metadata": {
"guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
}
}
assert guardrail.get_guardrail_from_metadata(data) == [
{"test-guardrail": {"extra_body": {"key": "value"}}}
]
def test_guardrail_is_in_requested_guardrails():
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
# Test with string list
assert (
guardrail._guardrail_is_in_requested_guardrails(["test-guardrail", "other"])
== True
)
assert guardrail._guardrail_is_in_requested_guardrails(["other"]) == False
# Test with dict list
assert (
guardrail._guardrail_is_in_requested_guardrails(
[{"test-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
)
== True
)
assert (
guardrail._guardrail_is_in_requested_guardrails(
[
{
"other-guardrail": {"extra_body": {"extra_key": "extra_value"}},
"test-guardrail": {"extra_body": {"extra_key": "extra_value"}},
}
]
)
== True
)
assert (
guardrail._guardrail_is_in_requested_guardrails(
[{"other-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
)
== False
)
def test_should_run_guardrail():
guardrail = CustomGuardrail(
guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call
)
# Test matching event hook and guardrail
assert (
guardrail.should_run_guardrail(
{"metadata": {"guardrails": ["test-guardrail"]}},
GuardrailEventHooks.pre_call,
)
== True
)
# Test non-matching event hook
assert (
guardrail.should_run_guardrail(
{"metadata": {"guardrails": ["test-guardrail"]}},
GuardrailEventHooks.during_call,
)
== False
)
# Test guardrail not in requested list
assert (
guardrail.should_run_guardrail(
{"metadata": {"guardrails": ["other-guardrail"]}},
GuardrailEventHooks.pre_call,
)
== False
)
def test_get_guardrail_dynamic_request_body_params():
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
# Test with no extra_body
data = {"metadata": {"guardrails": [{"test-guardrail": {}}]}}
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}
# Test with extra_body
data = {
"metadata": {
"guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
}
}
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {"key": "value"}
# Test with non-matching guardrail
data = {
"metadata": {
"guardrails": [{"other-guardrail": {"extra_body": {"key": "value"}}}]
}
}
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}