(feat) Support Dynamic Params for guardrails (#7415)

* update CustomGuardrail
* unit test custom guardrails
* add dynamic params for aporia
* add dynamic params to bedrock guard
* add dynamic params for all guardrails
* fix linting
* fix should_run_guardrail
* _validate_premium_user
* update guardrail doc
* doc update
* update code q
* should_run_guardrail
Parent: 77fa751639
Commit: 0ce5f9fe58
10 changed files with 411 additions and 21 deletions
@ -114,6 +114,88 @@ curl -i http://localhost:4000/v1/chat/completions \
## Advanced

### ✨ Pass additional parameters to guardrail

:::info

✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

Use this to pass additional parameters to the guardrail API call, e.g. a success threshold. **[See `guardrails` spec for more details](#spec-guardrails-parameter)**

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

Set `"guardrails": [{"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}]` to pass additional parameters to the guardrail.

In this example, `success_threshold=0.9` is added to the `aporia-pre-guard` guardrail request body.

```python
import openai

client = openai.OpenAI(
    api_key="anything",
    base_url="http://0.0.0.0:4000"
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, write a short poem"
        }
    ],
    extra_body={
        "guardrails": [
            {
                "aporia-pre-guard": {
                    "extra_body": {
                        "success_threshold": 0.9
                    }
                }
            }
        ]
    }
)

print(response)
```
</TabItem>

<TabItem value="Curl" label="Curl Request">

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ],
    "guardrails": [
        {
            "aporia-pre-guard": {
                "extra_body": {
                    "success_threshold": 0.9
                }
            }
        }
    ]
}'
```
</TabItem>

</Tabs>
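The helpers added in this PR inspect each entry of the `guardrails` list individually, so mixing plain guardrail names with per-guardrail overrides in a single request should also work. A minimal sketch of that shape (guardrail names are illustrative):

```python
# sketch: one request mixing a default guardrail with a parameterized one
guardrails = [
    "aporia-post-guard",  # run with the settings from config.yaml
    {"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}},  # per-request override
]
```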

### ✨ Control Guardrails per Project (API Key)

:::info
@ -252,4 +334,43 @@ Expected response
{
    "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
}
```

## Spec: `guardrails` Parameter

The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`).

### Format Options

1. Simple List Format:
```python
"guardrails": [
    "aporia-pre-guard",
    "aporia-post-guard"
]
```

2. Advanced Dictionary Format:

In this format, each dictionary key is the name of the guardrail you want to run:
```python
"guardrails": {
    "aporia-pre-guard": {
        "extra_body": {
            "success_threshold": 0.9,
            "other_param": "value"
        }
    }
}
```

### Type Definition
```python
guardrails: Union[
    List[str],                         # Simple list of guardrail names
    Dict[str, DynamicGuardrailParams]  # Advanced configuration
]

class DynamicGuardrailParams:
    extra_body: Dict[str, Any]  # Additional parameters for the guardrail
```
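Whatever you place in `extra_body` is merged into the JSON body that LiteLLM sends to the guardrail provider (see the `data.update(...)` calls in the guardrail integrations below). A rough sketch of that merge, with an illustrative provider payload:

```python
# sketch: how `extra_body` ends up in the guardrail provider's request body
provider_payload = {"messages": [{"role": "user", "content": "hi"}]}  # built by the integration
dynamic_params = {"success_threshold": 0.9}  # from guardrails["aporia-pre-guard"]["extra_body"]
provider_payload.update(dynamic_params)
# provider_payload now also contains "success_threshold": 0.9
```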
@ -1,8 +1,8 @@
from typing import List, Optional
from typing import Dict, List, Optional, Union

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks


class CustomGuardrail(CustomLogger):
@ -26,9 +26,31 @@ class CustomGuardrail(CustomLogger):
        )
        super().__init__(**kwargs)

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
    def get_guardrail_from_metadata(
        self, data: dict
    ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
        """
        Returns the guardrail(s) to be run from the metadata
        """
        metadata = data.get("metadata") or {}
        requested_guardrails = metadata.get("guardrails") or []
        return requested_guardrails

    def _guardrail_is_in_requested_guardrails(
        self,
        requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
    ) -> bool:
        for _guardrail in requested_guardrails:
            if isinstance(_guardrail, dict):
                if self.guardrail_name in _guardrail:
                    return True
            elif isinstance(_guardrail, str):
                if self.guardrail_name == _guardrail:
                    return True
        return False

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        requested_guardrails = self.get_guardrail_from_metadata(data)

        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
@ -40,7 +62,7 @@ class CustomGuardrail(CustomLogger):

        if (
            self.event_hook
            and self.guardrail_name not in requested_guardrails
            and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
            and event_type.value != "logging_only"
        ):
            return False
@ -49,3 +71,51 @@ class CustomGuardrail(CustomLogger):
            return False

        return True

    def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
        """
        Returns `extra_body` to be added to the request body for the Guardrail API call

        Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.

        ```
        [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
        ```

        Will return, for guardrail=`lakera_guard`:
        {
            "foo": "bar"
        }

        Args:
            request_data: The original `request_data` passed to LiteLLM Proxy
        """
        requested_guardrails = self.get_guardrail_from_metadata(request_data)

        # Look for the guardrail configuration matching self.guardrail_name
        for guardrail in requested_guardrails:
            if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
                # Get the configuration for this guardrail
                guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
                    **guardrail[self.guardrail_name]
                )
                if self._validate_premium_user() is not True:
                    return {}

                # Return the extra_body if it exists, otherwise empty dict
                return guardrail_config.get("extra_body", {})

        return {}

    def _validate_premium_user(self) -> bool:
        """
        Returns True if the user is a premium user
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        if premium_user is not True:
            verbose_logger.warning(
                f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
            )
            return False
        return True
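The built-in integrations below use this helper in roughly the following way; the sketch is illustrative (the class name, method, and payload keys are placeholders, only `CustomGuardrail` and `get_guardrail_dynamic_request_body_params` come from this change):

```python
# minimal sketch of a custom guardrail consuming dynamic request params
from litellm.integrations.custom_guardrail import CustomGuardrail


class MyGuardrail(CustomGuardrail):
    def build_guardrail_payload(self, request_data: dict, messages: list) -> dict:
        # base payload expected by your guardrail API (illustrative)
        payload = {"messages": messages}
        # merge per-request overrides, e.g. {"success_threshold": 0.9} taken from
        # guardrails=[{"my-guardrail": {"extra_body": {...}}}] on the incoming request
        payload.update(
            self.get_guardrail_dynamic_request_body_params(request_data=request_data)
        )
        return payload
```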
@ -86,12 +86,19 @@ class AporiaGuardrail(CustomGuardrail):
        return data

    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
        self,
        request_data: dict,
        new_messages: List[dict],
        response_string: Optional[str] = None,
    ):
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )

        data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=request_data)
        )

        _json_data = json.dumps(data)

        """
@ -155,7 +162,9 @@ class AporiaGuardrail(CustomGuardrail):
        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
                response_string=response_str, new_messages=data.get("messages", [])
                request_data=data,
                response_string=response_str,
                new_messages=data.get("messages", []),
            )

        add_guardrail_to_applied_guardrails_header(
@ -199,7 +208,10 @@ class AporiaGuardrail(CustomGuardrail):
        new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            await self.make_aporia_api_request(new_messages=new_messages)
            await self.make_aporia_api_request(
                request_data=data,
                new_messages=new_messages,
            )
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
@ -149,7 +149,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    def _prepare_request(
        self,
        credentials,
        data: BedrockRequest,
        data: dict,
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
@ -186,18 +186,23 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    ):

        credentials, aws_region_name = self._load_credentials()
        request_data: BedrockRequest = self.convert_to_bedrock_format(
            messages=kwargs.get("messages"), response=response
        bedrock_request_data: dict = dict(
            self.convert_to_bedrock_format(
                messages=kwargs.get("messages"), response=response
            )
        )
        bedrock_request_data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
        )
        prepared_request = self._prepare_request(
            credentials=credentials,
            data=request_data,
            data=bedrock_request_data,
            optional_params=self.optional_params,
            aws_region_name=aws_region_name,
        )
        verbose_proxy_logger.debug(
            "Bedrock AI request body: %s, url %s, headers: %s",
            request_data,
            bedrock_request_data,
            prepared_request.url,
            prepared_request.headers,
        )
@ -48,10 +48,13 @@ class GuardrailsAI(CustomGuardrail):
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)

    async def make_guardrails_ai_api_request(self, llm_output: str):
    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        from httpx import URL

        data = {"llmOutput": llm_output}
        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)
        response = await litellm.module_level_aclient.post(
            url=str(
@ -96,7 +99,9 @@ class GuardrailsAI(CustomGuardrail):

        response_str: str = get_content_from_model_response(response)
        if response_str is not None and len(response_str) > 0:
            await self.make_guardrails_ai_api_request(llm_output=response_str)
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
@ -216,14 +216,27 @@ class lakeraAI_Moderation(CustomGuardrail):
                    "Skipping lakera prompt injection, no roles with messages found"
                )
                return
            data = {"input": lakera_input}
            _json_data = json.dumps(data)
            _data = {
                "input": lakera_input,
                **self.get_guardrail_dynamic_request_body_params(request_data=data),
            }
            _json_data = json.dumps(_data)
        elif "input" in data and isinstance(data["input"], str):
            text = data["input"]
            _json_data = json.dumps({"input": text})
            _json_data = json.dumps(
                {
                    "input": text,
                    **self.get_guardrail_dynamic_request_body_params(request_data=data),
                }
            )
        elif "input" in data and isinstance(data["input"], list):
            text = "\n".join(data["input"])
            _json_data = json.dumps({"input": text})
            _json_data = json.dumps(
                {
                    "input": text,
                    **self.get_guardrail_dynamic_request_body_params(request_data=data),
                }
            )

        verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
@ -132,6 +132,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        text: str,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> str:
        """
        [TODO] make this more performant for high-throughput scenario
@ -150,7 +151,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
            if self.ad_hoc_recognizers is not None:
                analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
            # End of constructing Request 1

            analyze_payload.update(
                self.get_guardrail_dynamic_request_body_params(
                    request_data=request_data
                )
            )
            redacted_text = None
            verbose_proxy_logger.debug(
                "Making request to: %s with payload: %s",
@ -235,6 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                            text=m["content"],
                            output_parse_pii=self.output_parse_pii,
                            presidio_config=presidio_config,
                            request_data=data,
                        )
                    )
            responses = await asyncio.gather(*tasks)
@ -311,6 +317,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
                    text=text_str,
                    output_parse_pii=False,
                    presidio_config=presidio_config,
                    request_data=kwargs,
                )
            )  # need to pass separately b/c presidio has context window limits
        responses = await asyncio.gather(*tasks)
@ -12,6 +12,14 @@ model_list:
      model: bedrock/*


guardrails:
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
      mode: "during_call"
      guardrailIdentifier: ff6ujrregl1q
      guardrailVersion: "DRAFT"

# for /files endpoints
# For /fine_tuning/jobs endpoints
finetune_settings:
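A request that exercises the `bedrock-pre-guard` guardrail defined above and passes per-request dynamic params would look roughly like the docs example earlier; the `extra_body` keys here are illustrative placeholders, not parameters Bedrock necessarily accepts:

```python
# sketch: calling the proxy configured above with dynamic params for bedrock-pre-guard
import openai

client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={
        "guardrails": [
            {"bedrock-pre-guard": {"extra_body": {"customKey": "customValue"}}}
        ]
    },
)
```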
@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Literal, Optional, TypedDict
from typing import Any, Dict, List, Literal, Optional, TypedDict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@ -132,3 +132,7 @@ class BedrockContentItem(TypedDict, total=False):
class BedrockRequest(TypedDict, total=False):
    source: Literal["INPUT", "OUTPUT"]
    content: List[BedrockContentItem]


class DynamicGuardrailParams(TypedDict):
    extra_body: Dict[str, Any]
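A quick sketch of how this TypedDict lines up with the metadata shape consumed by `get_guardrail_from_metadata` (the guardrail name is illustrative):

```python
# sketch: typing a per-guardrail override with DynamicGuardrailParams
from litellm.types.guardrails import DynamicGuardrailParams

params: DynamicGuardrailParams = {"extra_body": {"success_threshold": 0.9}}
metadata = {"guardrails": [{"aporia-pre-guard": params}]}
```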
tests/logging_callback_tests/test_custom_guardrail.py (new file, 145 lines)
@ -0,0 +1,145 @@
import io
import os
import sys


sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.custom_guardrail import CustomGuardrail


from typing import Any, Dict, List, Literal, Optional, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks


def test_get_guardrail_from_metadata():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with empty metadata
    assert guardrail.get_guardrail_from_metadata({}) == []

    # Test with guardrails in metadata
    data = {"metadata": {"guardrails": ["guardrail1", "guardrail2"]}}
    assert guardrail.get_guardrail_from_metadata(data) == ["guardrail1", "guardrail2"]

    # Test with dict guardrails
    data = {
        "metadata": {
            "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_from_metadata(data) == [
        {"test-guardrail": {"extra_body": {"key": "value"}}}
    ]


def test_guardrail_is_in_requested_guardrails():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with string list
    assert (
        guardrail._guardrail_is_in_requested_guardrails(["test-guardrail", "other"])
        == True
    )
    assert guardrail._guardrail_is_in_requested_guardrails(["other"]) == False

    # Test with dict list
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [{"test-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
        )
        == True
    )
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [
                {
                    "other-guardrail": {"extra_body": {"extra_key": "extra_value"}},
                    "test-guardrail": {"extra_body": {"extra_key": "extra_value"}},
                }
            ]
        )
        == True
    )
    assert (
        guardrail._guardrail_is_in_requested_guardrails(
            [{"other-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
        )
        == False
    )


def test_should_run_guardrail():
    guardrail = CustomGuardrail(
        guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call
    )

    # Test matching event hook and guardrail
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["test-guardrail"]}},
            GuardrailEventHooks.pre_call,
        )
        == True
    )

    # Test non-matching event hook
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["test-guardrail"]}},
            GuardrailEventHooks.during_call,
        )
        == False
    )

    # Test guardrail not in requested list
    assert (
        guardrail.should_run_guardrail(
            {"metadata": {"guardrails": ["other-guardrail"]}},
            GuardrailEventHooks.pre_call,
        )
        == False
    )


def test_get_guardrail_dynamic_request_body_params():
    guardrail = CustomGuardrail(guardrail_name="test-guardrail")

    # Test with no extra_body
    data = {"metadata": {"guardrails": [{"test-guardrail": {}}]}}
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}

    # Test with extra_body
    data = {
        "metadata": {
            "guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {"key": "value"}

    # Test with non-matching guardrail
    data = {
        "metadata": {
            "guardrails": [{"other-guardrail": {"extra_body": {"key": "value"}}}]
        }
    }
    assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}