(feat) Support Dynamic Params for guardrails (#7415)

* update CustomGuardrail

* unit test custom guardrails

* add dynamic params for aporia

* add dynamic params to bedrock guard

* add dynamic params for all guardrails

* fix linting

* fix should_run_guardrail

* _validate_premium_user

* update guardrail doc

* doc update

* update code q

* should_run_guardrail
Authored by Ishaan Jaff on 2024-12-25 16:07:29 -08:00; committed by GitHub
parent 43670545b4
commit 5612103ea3
10 changed files with 411 additions and 21 deletions

View file

@@ -114,6 +114,88 @@ curl -i http://localhost:4000/v1/chat/completions \
## Advanced
### ✨ Pass additional parameters to guardrail
:::info
✨ This is an Enterprise-only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
Use this to pass additional parameters to the guardrail API call, e.g. a success threshold. **[See `guardrails` spec for more details](#spec-guardrails-parameter)**
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
Set `guardrails=[{"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}]` to pass additional parameters to the guardrail.
In this example, `success_threshold=0.9` is added to the request body sent to the `aporia-pre-guard` guardrail.
```python
import openai

client = openai.OpenAI(
    api_key="anything",
    base_url="http://0.0.0.0:4000"
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, write a short poem"
        }
    ],
    extra_body={
        "guardrails": [
            {
                "aporia-pre-guard": {
                    "extra_body": {
                        "success_threshold": 0.9
                    }
                }
            }
        ]
    }
)

print(response)
```
</TabItem>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
    --header 'Content-Type: application/json' \
    --data '{
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "what llm are you"
            }
        ],
        "guardrails": [
            {
                "aporia-pre-guard": {
                    "extra_body": {
                        "success_threshold": 0.9
                    }
                }
            }
        ]
    }'
```
</TabItem>
</Tabs>
### ✨ Control Guardrails per Project (API Key)
:::info
@@ -252,4 +334,43 @@ Expected response
{
    "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}
```
## Spec: `guardrails` Parameter
The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`).
### Format Options
1. Simple List Format:
```python
"guardrails": [
"aporia-pre-guard",
"aporia-post-guard"
]
```
2. Advanced Format (list of dictionaries):
In this format, each dictionary key is the `guardrail_name` you want to run:
```python
"guardrails": [
    {
        "aporia-pre-guard": {
            "extra_body": {
                "success_threshold": 0.9,
                "other_param": "value"
            }
        }
    }
]
```
### Type Definition
```python
guardrails: Union[
    List[str],                               # Simple list of guardrail names
    List[Dict[str, DynamicGuardrailParams]]  # Advanced configuration
]

class DynamicGuardrailParams(TypedDict):
    extra_body: Dict[str, Any]  # Additional parameters for the guardrail API call
```
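For reference, a guardrail hook sees the requested guardrails under request `metadata` (this is what `CustomGuardrail.get_guardrail_from_metadata` reads, and the shape the unit tests added in this commit use). A minimal sketch of that normalized request data:

```python
# Sketch only: the proxy's request pre-processing places the caller's
# `guardrails` value under metadata before guardrail hooks run.
request_data = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {
        "guardrails": [
            {"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}
        ]
    },
}
```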

View file

@@ -1,8 +1,8 @@
-from typing import List, Optional
from typing import Dict, List, Optional, Union

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
-from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks


class CustomGuardrail(CustomLogger):
@@ -26,9 +26,31 @@ class CustomGuardrail(CustomLogger):
            )
        super().__init__(**kwargs)

-    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
    def get_guardrail_from_metadata(
        self, data: dict
    ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
        """
        Returns the guardrail(s) to be run from the metadata
        """
        metadata = data.get("metadata") or {}
        requested_guardrails = metadata.get("guardrails") or []
        return requested_guardrails

    def _guardrail_is_in_requested_guardrails(
        self,
        requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
    ) -> bool:
        for _guardrail in requested_guardrails:
            if isinstance(_guardrail, dict):
                if self.guardrail_name in _guardrail:
                    return True
            elif isinstance(_guardrail, str):
                if self.guardrail_name == _guardrail:
                    return True
        return False

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        requested_guardrails = self.get_guardrail_from_metadata(data)

        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
@@ -40,7 +62,7 @@ class CustomGuardrail(CustomLogger):
        if (
            self.event_hook
-            and self.guardrail_name not in requested_guardrails
            and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
            and event_type.value != "logging_only"
        ):
            return False
@@ -49,3 +71,51 @@ class CustomGuardrail(CustomLogger):
            return False

        return True
    def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
        """
        Returns `extra_body` to be added to the request body for the Guardrail API call

        Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.

        ```
        [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
        ```

        Will return: for guardrail=`lakera-guard`:
        {
            "foo": "bar"
        }

        Args:
            request_data: The original `request_data` passed to LiteLLM Proxy
        """
        requested_guardrails = self.get_guardrail_from_metadata(request_data)

        # Look for the guardrail configuration matching self.guardrail_name
        for guardrail in requested_guardrails:
            if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
                # Get the configuration for this guardrail
                guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
                    **guardrail[self.guardrail_name]
                )
                if self._validate_premium_user() is not True:
                    return {}

                # Return the extra_body if it exists, otherwise empty dict
                return guardrail_config.get("extra_body", {})

        return {}

    def _validate_premium_user(self) -> bool:
        """
        Returns True if the user is a premium user
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        if premium_user is not True:
            verbose_logger.warning(
                f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
            )
            return False
        return True
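A short usage sketch of the new helpers, modeled on the unit tests added at the end of this commit (a plain `CustomGuardrail` instance is used for illustration; in practice these methods are called from a guardrail subclass's hooks):

```python
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

guardrail = CustomGuardrail(
    guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call
)

data = {
    "metadata": {
        "guardrails": [{"test-guardrail": {"extra_body": {"success_threshold": 0.9}}}]
    }
}

# Runs only when this guardrail is requested and the event hook matches
assert guardrail.should_run_guardrail(data, GuardrailEventHooks.pre_call)

# Returns the per-request extra_body to merge into the guardrail API call
# (returns {} and logs a warning for non-premium users, per _validate_premium_user)
print(guardrail.get_guardrail_dynamic_request_body_params(data))
```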

View file

@@ -86,12 +86,19 @@ class AporiaGuardrail(CustomGuardrail):
        return data

    async def make_aporia_api_request(
-        self, new_messages: List[dict], response_string: Optional[str] = None
        self,
        request_data: dict,
        new_messages: List[dict],
        response_string: Optional[str] = None,
    ):
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )
        data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=request_data)
        )

        _json_data = json.dumps(data)

        """
@@ -155,7 +162,9 @@ class AporiaGuardrail(CustomGuardrail):
        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
-                response_string=response_str, new_messages=data.get("messages", [])
                request_data=data,
                response_string=response_str,
                new_messages=data.get("messages", []),
            )

            add_guardrail_to_applied_guardrails_header(
@@ -199,7 +208,10 @@ class AporiaGuardrail(CustomGuardrail):
        new_messages = self.transform_messages(messages=data["messages"])
        if new_messages is not None:
-            await self.make_aporia_api_request(new_messages=new_messages)
            await self.make_aporia_api_request(
                request_data=data,
                new_messages=new_messages,
            )
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
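The net effect of the `data.update(...)` call is that any `extra_body` keys supplied for an Aporia guardrail are merged top-level into the JSON body sent to Aporia. A minimal sketch of that merge (base payload fields are elided; only `success_threshold` comes from the docs example above):

```python
# Base payload as built by prepare_aporia_request (fields elided for brevity)
data = {"messages": [{"role": "user", "content": "hi"}]}

# Per-request dynamic params are merged on top before the API call
data.update({"success_threshold": 0.9})
print(data)  # {'messages': [...], 'success_threshold': 0.9}
```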

View file

@@ -149,7 +149,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    def _prepare_request(
        self,
        credentials,
-        data: BedrockRequest,
        data: dict,
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
@@ -186,18 +186,23 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    ):
        credentials, aws_region_name = self._load_credentials()

-        request_data: BedrockRequest = self.convert_to_bedrock_format(
-            messages=kwargs.get("messages"), response=response
        bedrock_request_data: dict = dict(
            self.convert_to_bedrock_format(
                messages=kwargs.get("messages"), response=response
            )
        )
        bedrock_request_data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
        )
        prepared_request = self._prepare_request(
            credentials=credentials,
-            data=request_data,
            data=bedrock_request_data,
            optional_params=self.optional_params,
            aws_region_name=aws_region_name,
        )
        verbose_proxy_logger.debug(
            "Bedrock AI request body: %s, url %s, headers: %s",
-            request_data,
            bedrock_request_data,
            prepared_request.url,
            prepared_request.headers,
        )
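Note the related signature change above: `_prepare_request` now takes `data: dict` instead of `data: BedrockRequest`, since the merged payload may carry `extra_body` keys that are not part of the `BedrockRequest` TypedDict. A sketch of the merged body (the `content` entries and the extra key are illustrative):

```python
# Shape after convert_to_bedrock_format(...) plus the merged extra_body (illustrative values)
bedrock_request_data = {
    "source": "INPUT",              # BedrockRequest field
    "content": [...],               # list of BedrockContentItem entries (shape omitted)
    "some_dynamic_param": "value",  # merged from guardrails[...]["extra_body"]
}
```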

View file

@@ -48,10 +48,13 @@ class GuardrailsAI(CustomGuardrail):
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)

-    async def make_guardrails_ai_api_request(self, llm_output: str):
    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        from httpx import URL

-        data = {"llmOutput": llm_output}
        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)
        response = await litellm.module_level_aclient.post(
            url=str(
@@ -96,7 +99,9 @@ class GuardrailsAI(CustomGuardrail):
        response_str: str = get_content_from_model_response(response)
        if response_str is not None and len(response_str) > 0:
-            await self.make_guardrails_ai_api_request(llm_output=response_str)
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name

View file

@@ -216,14 +216,27 @@ class lakeraAI_Moderation(CustomGuardrail):
                    "Skipping lakera prompt injection, no roles with messages found"
                )
                return
-            data = {"input": lakera_input}
-            _json_data = json.dumps(data)
            _data = {
                "input": lakera_input,
                **self.get_guardrail_dynamic_request_body_params(request_data=data),
            }
            _json_data = json.dumps(_data)
elif "input" in data and isinstance(data["input"], str): elif "input" in data and isinstance(data["input"], str):
text = data["input"] text = data["input"]
_json_data = json.dumps({"input": text}) _json_data = json.dumps(
{
"input": text,
**self.get_guardrail_dynamic_request_body_params(request_data=data),
}
)
elif "input" in data and isinstance(data["input"], list): elif "input" in data and isinstance(data["input"], list):
text = "\n".join(data["input"]) text = "\n".join(data["input"])
_json_data = json.dumps({"input": text}) _json_data = json.dumps(
{
"input": text,
**self.get_guardrail_dynamic_request_body_params(request_data=data),
}
)
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data) verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)

View file

@@ -132,6 +132,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        text: str,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> str:
        """
        [TODO] make this more performant for high-throughput scenario
@@ -150,7 +151,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
            if self.ad_hoc_recognizers is not None:
                analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
            # End of constructing Request 1

            analyze_payload.update(
                self.get_guardrail_dynamic_request_body_params(
                    request_data=request_data
                )
            )
            redacted_text = None

            verbose_proxy_logger.debug(
                "Making request to: %s with payload: %s",
@@ -235,6 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                        text=m["content"],
                        output_parse_pii=self.output_parse_pii,
                        presidio_config=presidio_config,
                        request_data=data,
                    )
                )
            responses = await asyncio.gather(*tasks)
@@ -311,6 +317,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                        text=text_str,
                        output_parse_pii=False,
                        presidio_config=presidio_config,
                        request_data=kwargs,
                    )
                ) # need to pass separately b/c presidio has context window limits
            responses = await asyncio.gather(*tasks)

View file

@@ -12,6 +12,14 @@ model_list:
      model: bedrock/*

guardrails:
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
      mode: "during_call"
      guardrailIdentifier: ff6ujrregl1q
      guardrailVersion: "DRAFT"

# for /files endpoints
# For /fine_tuning/jobs endpoints
finetune_settings:
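With this config in place, a request can opt into `bedrock-pre-guard` and pass extra parameters to it, following the docs example earlier in this commit (a sketch; the API key, model name, and the extra key are placeholders, not documented Bedrock parameters):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this is a test request"}],
    extra_body={
        "guardrails": [
            # extra_body keys are merged into the Bedrock guardrail request body
            {"bedrock-pre-guard": {"extra_body": {"some_param": "value"}}}
        ]
    },
)
print(response)
```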

View file

@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict, List, Literal, Optional, TypedDict
from typing import Any, Dict, List, Literal, Optional, TypedDict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@@ -132,3 +132,7 @@ class BedrockContentItem(TypedDict, total=False):
class BedrockRequest(TypedDict, total=False):
    source: Literal["INPUT", "OUTPUT"]
    content: List[BedrockContentItem]


class DynamicGuardrailParams(TypedDict):
    extra_body: Dict[str, Any]
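A minimal sketch of the new type in use; since `DynamicGuardrailParams` is a plain `TypedDict`, a dict literal with an `extra_body` key satisfies it:

```python
from typing import Any, Dict

from litellm.types.guardrails import DynamicGuardrailParams

params: DynamicGuardrailParams = {"extra_body": {"success_threshold": 0.9}}
extra: Dict[str, Any] = params["extra_body"]  # what get_guardrail_dynamic_request_body_params returns
```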

View file

@@ -0,0 +1,145 @@
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from typing import Any, Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
def test_get_guardrail_from_metadata():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with empty metadata
    assert guardrail.get_guardrail_from_metadata({}) == []

    # Test with guardrails in metadata
    data = {"metadata": {"guardrails": ["guardrail1", "guardrail2"]}}
    assert guardrail.get_guardrail_from_metadata(data) == ["guardrail1", "guardrail2"]

    # Test with dict guardrails
    data = {
        "metadata": {
            "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_from_metadata(data) == [
        {"test-guardrail": {"extra_body": {"key": "value"}}}
    ]


def test_guardrail_is_in_requested_guardrails():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with string list
    assert (
        guardrail._guardrail_is_in_requested_guardrails(["test-guardrail", "other"])
        == True
    )
    assert guardrail._guardrail_is_in_requested_guardrails(["other"]) == False

    # Test with dict list
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [{"test-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
        )
        == True
    )
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [
                {
                    "other-guardrail": {"extra_body": {"extra_key": "extra_value"}},
                    "test-guardrail": {"extra_body": {"extra_key": "extra_value"}},
                }
            ]
        )
        == True
    )
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [{"other-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
        )
        == False
    )


def test_should_run_guardrail():
    guardrail = CustomGuardrail(
        guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call
    )

    # Test matching event hook and guardrail
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["test-guardrail"]}},
            GuardrailEventHooks.pre_call,
        )
        == True
    )

    # Test non-matching event hook
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["test-guardrail"]}},
            GuardrailEventHooks.during_call,
        )
        == False
    )

    # Test guardrail not in requested list
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["other-guardrail"]}},
            GuardrailEventHooks.pre_call,
        )
        == False
    )


def test_get_guardrail_dynamic_request_body_params():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with no extra_body
    data = {"metadata": {"guardrails": [{"test-guardrail": {}}]}}
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}

    # Test with extra_body
    data = {
        "metadata": {
            "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {"key": "value"}

    # Test with non-matching guardrail
    data = {
        "metadata": {
            "guardrails": [{"other-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}