mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(feat) Support Dynamic Params for guardrails
(#7415)
* update CustomGuardrail * unit test custom guardrails * add dynamic params for aporia * add dynamic params to bedrock guard * add dynamic params for all guardrails * fix linting * fix should_run_guardrail * _validate_premium_user * update guardrail doc * doc update * update code q * should_run_guardrail
This commit is contained in:
parent
77fa751639
commit
0ce5f9fe58
10 changed files with 411 additions and 21 deletions
|
@ -114,6 +114,88 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||||
|
|
||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
|
### ✨ Pass additional parameters to guardrail
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Use this to pass additional parameters to the guardrail API call. e.g. things like success threshold. **[See `guardrails` spec for more details](#spec-guardrails-parameter)**
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
Set `guardrails={"aporia-pre-guard": {"extra_body": {"success_threshold": 0.9}}}` to pass additional parameters to the guardrail
|
||||||
|
|
||||||
|
In this example `success_threshold=0.9` is passed to the `aporia-pre-guard` guardrail request body
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={
|
||||||
|
"guardrails": [
|
||||||
|
"aporia-pre-guard": {
|
||||||
|
"extra_body": {
|
||||||
|
"success_threshold": 0.9
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
|
<TabItem value="Curl" label="Curl Request">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what llm are you"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"guardrails": [
|
||||||
|
"aporia-pre-guard": {
|
||||||
|
"extra_body": {
|
||||||
|
"success_threshold": 0.9
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
### ✨ Control Guardrails per Project (API Key)
|
### ✨ Control Guardrails per Project (API Key)
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
@ -253,3 +335,42 @@ Expected response
|
||||||
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
|
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Spec: `guardrails` Parameter
|
||||||
|
|
||||||
|
The `guardrails` parameter can be passed to any LiteLLM Proxy endpoint (`/chat/completions`, `/completions`, `/embeddings`).
|
||||||
|
|
||||||
|
### Format Options
|
||||||
|
|
||||||
|
1. Simple List Format:
|
||||||
|
```python
|
||||||
|
"guardrails": [
|
||||||
|
"aporia-pre-guard",
|
||||||
|
"aporia-post-guard"
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Advanced Dictionary Format:
|
||||||
|
|
||||||
|
In this format the dictionary key is `guardrail_name` you want to run
|
||||||
|
```python
|
||||||
|
"guardrails": {
|
||||||
|
"aporia-pre-guard": {
|
||||||
|
"extra_body": {
|
||||||
|
"success_threshold": 0.9,
|
||||||
|
"other_param": "value"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Type Definition
|
||||||
|
```python
|
||||||
|
guardrails: Union[
|
||||||
|
List[str], # Simple list of guardrail names
|
||||||
|
Dict[str, DynamicGuardrailParams] # Advanced configuration
|
||||||
|
]
|
||||||
|
|
||||||
|
class DynamicGuardrailParams:
|
||||||
|
extra_body: Dict[str, Any] # Additional parameters for the guardrail
|
||||||
|
```
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.types.guardrails import GuardrailEventHooks
|
from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
|
||||||
|
|
||||||
|
|
||||||
class CustomGuardrail(CustomLogger):
|
class CustomGuardrail(CustomLogger):
|
||||||
|
@ -26,9 +26,31 @@ class CustomGuardrail(CustomLogger):
|
||||||
)
|
)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
def get_guardrail_from_metadata(
|
||||||
|
self, data: dict
|
||||||
|
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
|
||||||
|
"""
|
||||||
|
Returns the guardrail(s) to be run from the metadata
|
||||||
|
"""
|
||||||
metadata = data.get("metadata") or {}
|
metadata = data.get("metadata") or {}
|
||||||
requested_guardrails = metadata.get("guardrails") or []
|
requested_guardrails = metadata.get("guardrails") or []
|
||||||
|
return requested_guardrails
|
||||||
|
|
||||||
|
def _guardrail_is_in_requested_guardrails(
|
||||||
|
self,
|
||||||
|
requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
|
||||||
|
) -> bool:
|
||||||
|
for _guardrail in requested_guardrails:
|
||||||
|
if isinstance(_guardrail, dict):
|
||||||
|
if self.guardrail_name in _guardrail:
|
||||||
|
return True
|
||||||
|
elif isinstance(_guardrail, str):
|
||||||
|
if self.guardrail_name == _guardrail:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||||
|
requested_guardrails = self.get_guardrail_from_metadata(data)
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
|
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
|
||||||
|
@ -40,7 +62,7 @@ class CustomGuardrail(CustomLogger):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.event_hook
|
self.event_hook
|
||||||
and self.guardrail_name not in requested_guardrails
|
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
|
||||||
and event_type.value != "logging_only"
|
and event_type.value != "logging_only"
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
@ -49,3 +71,51 @@ class CustomGuardrail(CustomLogger):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Returns `extra_body` to be added to the request body for the Guardrail API call
|
||||||
|
|
||||||
|
Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.
|
||||||
|
|
||||||
|
```
|
||||||
|
[{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
|
||||||
|
```
|
||||||
|
|
||||||
|
Will return: for guardrail=`lakera-guard`:
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: The original `request_data` passed to LiteLLM Proxy
|
||||||
|
"""
|
||||||
|
requested_guardrails = self.get_guardrail_from_metadata(request_data)
|
||||||
|
|
||||||
|
# Look for the guardrail configuration matching self.guardrail_name
|
||||||
|
for guardrail in requested_guardrails:
|
||||||
|
if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
|
||||||
|
# Get the configuration for this guardrail
|
||||||
|
guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
|
||||||
|
**guardrail[self.guardrail_name]
|
||||||
|
)
|
||||||
|
if self._validate_premium_user() is not True:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Return the extra_body if it exists, otherwise empty dict
|
||||||
|
return guardrail_config.get("extra_body", {})
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _validate_premium_user(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the user is a premium user
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
||||||
|
|
||||||
|
if premium_user is not True:
|
||||||
|
verbose_logger.warning(
|
||||||
|
f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
|
@ -86,12 +86,19 @@ class AporiaGuardrail(CustomGuardrail):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def make_aporia_api_request(
|
async def make_aporia_api_request(
|
||||||
self, new_messages: List[dict], response_string: Optional[str] = None
|
self,
|
||||||
|
request_data: dict,
|
||||||
|
new_messages: List[dict],
|
||||||
|
response_string: Optional[str] = None,
|
||||||
):
|
):
|
||||||
data = await self.prepare_aporia_request(
|
data = await self.prepare_aporia_request(
|
||||||
new_messages=new_messages, response_string=response_string
|
new_messages=new_messages, response_string=response_string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data.update(
|
||||||
|
self.get_guardrail_dynamic_request_body_params(request_data=request_data)
|
||||||
|
)
|
||||||
|
|
||||||
_json_data = json.dumps(data)
|
_json_data = json.dumps(data)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -155,7 +162,9 @@ class AporiaGuardrail(CustomGuardrail):
|
||||||
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||||
if response_str is not None:
|
if response_str is not None:
|
||||||
await self.make_aporia_api_request(
|
await self.make_aporia_api_request(
|
||||||
response_string=response_str, new_messages=data.get("messages", [])
|
request_data=data,
|
||||||
|
response_string=response_str,
|
||||||
|
new_messages=data.get("messages", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
add_guardrail_to_applied_guardrails_header(
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
@ -199,7 +208,10 @@ class AporiaGuardrail(CustomGuardrail):
|
||||||
new_messages = self.transform_messages(messages=data["messages"])
|
new_messages = self.transform_messages(messages=data["messages"])
|
||||||
|
|
||||||
if new_messages is not None:
|
if new_messages is not None:
|
||||||
await self.make_aporia_api_request(new_messages=new_messages)
|
await self.make_aporia_api_request(
|
||||||
|
request_data=data,
|
||||||
|
new_messages=new_messages,
|
||||||
|
)
|
||||||
add_guardrail_to_applied_guardrails_header(
|
add_guardrail_to_applied_guardrails_header(
|
||||||
request_data=data, guardrail_name=self.guardrail_name
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
)
|
)
|
||||||
|
|
|
@ -149,7 +149,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||||
def _prepare_request(
|
def _prepare_request(
|
||||||
self,
|
self,
|
||||||
credentials,
|
credentials,
|
||||||
data: BedrockRequest,
|
data: dict,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
aws_region_name: str,
|
aws_region_name: str,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
|
@ -186,18 +186,23 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||||
):
|
):
|
||||||
|
|
||||||
credentials, aws_region_name = self._load_credentials()
|
credentials, aws_region_name = self._load_credentials()
|
||||||
request_data: BedrockRequest = self.convert_to_bedrock_format(
|
bedrock_request_data: dict = dict(
|
||||||
|
self.convert_to_bedrock_format(
|
||||||
messages=kwargs.get("messages"), response=response
|
messages=kwargs.get("messages"), response=response
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
bedrock_request_data.update(
|
||||||
|
self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
|
||||||
|
)
|
||||||
prepared_request = self._prepare_request(
|
prepared_request = self._prepare_request(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
data=request_data,
|
data=bedrock_request_data,
|
||||||
optional_params=self.optional_params,
|
optional_params=self.optional_params,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Bedrock AI request body: %s, url %s, headers: %s",
|
"Bedrock AI request body: %s, url %s, headers: %s",
|
||||||
request_data,
|
bedrock_request_data,
|
||||||
prepared_request.url,
|
prepared_request.url,
|
||||||
prepared_request.headers,
|
prepared_request.headers,
|
||||||
)
|
)
|
||||||
|
|
|
@ -48,10 +48,13 @@ class GuardrailsAI(CustomGuardrail):
|
||||||
supported_event_hooks = [GuardrailEventHooks.post_call]
|
supported_event_hooks = [GuardrailEventHooks.post_call]
|
||||||
super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
|
super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
|
||||||
|
|
||||||
async def make_guardrails_ai_api_request(self, llm_output: str):
|
async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
|
||||||
from httpx import URL
|
from httpx import URL
|
||||||
|
|
||||||
data = {"llmOutput": llm_output}
|
data = {
|
||||||
|
"llmOutput": llm_output,
|
||||||
|
**self.get_guardrail_dynamic_request_body_params(request_data=request_data),
|
||||||
|
}
|
||||||
_json_data = json.dumps(data)
|
_json_data = json.dumps(data)
|
||||||
response = await litellm.module_level_aclient.post(
|
response = await litellm.module_level_aclient.post(
|
||||||
url=str(
|
url=str(
|
||||||
|
@ -96,7 +99,9 @@ class GuardrailsAI(CustomGuardrail):
|
||||||
|
|
||||||
response_str: str = get_content_from_model_response(response)
|
response_str: str = get_content_from_model_response(response)
|
||||||
if response_str is not None and len(response_str) > 0:
|
if response_str is not None and len(response_str) > 0:
|
||||||
await self.make_guardrails_ai_api_request(llm_output=response_str)
|
await self.make_guardrails_ai_api_request(
|
||||||
|
llm_output=response_str, request_data=data
|
||||||
|
)
|
||||||
|
|
||||||
add_guardrail_to_applied_guardrails_header(
|
add_guardrail_to_applied_guardrails_header(
|
||||||
request_data=data, guardrail_name=self.guardrail_name
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
|
|
@ -216,14 +216,27 @@ class lakeraAI_Moderation(CustomGuardrail):
|
||||||
"Skipping lakera prompt injection, no roles with messages found"
|
"Skipping lakera prompt injection, no roles with messages found"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
data = {"input": lakera_input}
|
_data = {"input": lakera_input}
|
||||||
_json_data = json.dumps(data)
|
_json_data = json.dumps(
|
||||||
|
_data,
|
||||||
|
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||||
|
)
|
||||||
elif "input" in data and isinstance(data["input"], str):
|
elif "input" in data and isinstance(data["input"], str):
|
||||||
text = data["input"]
|
text = data["input"]
|
||||||
_json_data = json.dumps({"input": text})
|
_json_data = json.dumps(
|
||||||
|
{
|
||||||
|
"input": text,
|
||||||
|
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||||
|
}
|
||||||
|
)
|
||||||
elif "input" in data and isinstance(data["input"], list):
|
elif "input" in data and isinstance(data["input"], list):
|
||||||
text = "\n".join(data["input"])
|
text = "\n".join(data["input"])
|
||||||
_json_data = json.dumps({"input": text})
|
_json_data = json.dumps(
|
||||||
|
{
|
||||||
|
"input": text,
|
||||||
|
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
|
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
|
||||||
|
|
||||||
|
|
|
@ -132,6 +132,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||||
text: str,
|
text: str,
|
||||||
output_parse_pii: bool,
|
output_parse_pii: bool,
|
||||||
presidio_config: Optional[PresidioPerRequestConfig],
|
presidio_config: Optional[PresidioPerRequestConfig],
|
||||||
|
request_data: dict,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
[TODO] make this more performant for high-throughput scenario
|
[TODO] make this more performant for high-throughput scenario
|
||||||
|
@ -150,7 +151,11 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||||
if self.ad_hoc_recognizers is not None:
|
if self.ad_hoc_recognizers is not None:
|
||||||
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
|
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
|
||||||
# End of constructing Request 1
|
# End of constructing Request 1
|
||||||
|
analyze_payload.update(
|
||||||
|
self.get_guardrail_dynamic_request_body_params(
|
||||||
|
request_data=request_data
|
||||||
|
)
|
||||||
|
)
|
||||||
redacted_text = None
|
redacted_text = None
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Making request to: %s with payload: %s",
|
"Making request to: %s with payload: %s",
|
||||||
|
@ -235,6 +240,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||||
text=m["content"],
|
text=m["content"],
|
||||||
output_parse_pii=self.output_parse_pii,
|
output_parse_pii=self.output_parse_pii,
|
||||||
presidio_config=presidio_config,
|
presidio_config=presidio_config,
|
||||||
|
request_data=data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
responses = await asyncio.gather(*tasks)
|
responses = await asyncio.gather(*tasks)
|
||||||
|
@ -311,6 +317,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||||
text=text_str,
|
text=text_str,
|
||||||
output_parse_pii=False,
|
output_parse_pii=False,
|
||||||
presidio_config=presidio_config,
|
presidio_config=presidio_config,
|
||||||
|
request_data=kwargs,
|
||||||
)
|
)
|
||||||
) # need to pass separately b/c presidio has context window limits
|
) # need to pass separately b/c presidio has context window limits
|
||||||
responses = await asyncio.gather(*tasks)
|
responses = await asyncio.gather(*tasks)
|
||||||
|
|
|
@ -12,6 +12,14 @@ model_list:
|
||||||
model: bedrock/*
|
model: bedrock/*
|
||||||
|
|
||||||
|
|
||||||
|
guardrails:
|
||||||
|
- guardrail_name: "bedrock-pre-guard"
|
||||||
|
litellm_params:
|
||||||
|
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
|
||||||
|
mode: "during_call"
|
||||||
|
guardrailIdentifier: ff6ujrregl1q
|
||||||
|
guardrailVersion: "DRAFT"
|
||||||
|
|
||||||
# for /files endpoints
|
# for /files endpoints
|
||||||
# For /fine_tuning/jobs endpoints
|
# For /fine_tuning/jobs endpoints
|
||||||
finetune_settings:
|
finetune_settings:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Literal, Optional, TypedDict
|
from typing import Any, Dict, List, Literal, Optional, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
@ -132,3 +132,7 @@ class BedrockContentItem(TypedDict, total=False):
|
||||||
class BedrockRequest(TypedDict, total=False):
|
class BedrockRequest(TypedDict, total=False):
|
||||||
source: Literal["INPUT", "OUTPUT"]
|
source: Literal["INPUT", "OUTPUT"]
|
||||||
content: List[BedrockContentItem]
|
content: List[BedrockContentItem]
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicGuardrailParams(TypedDict):
|
||||||
|
extra_body: Dict[str, Any]
|
||||||
|
|
145
tests/logging_callback_tests/test_custom_guardrail.py
Normal file
145
tests/logging_callback_tests/test_custom_guardrail.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import completion
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching.caching import DualCache
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_guardrail_from_metadata():
|
||||||
|
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
|
||||||
|
|
||||||
|
# Test with empty metadata
|
||||||
|
assert guardrail.get_guardrail_from_metadata({}) == []
|
||||||
|
|
||||||
|
# Test with guardrails in metadata
|
||||||
|
data = {"metadata": {"guardrails": ["guardrail1", "guardrail2"]}}
|
||||||
|
assert guardrail.get_guardrail_from_metadata(data) == ["guardrail1", "guardrail2"]
|
||||||
|
|
||||||
|
# Test with dict guardrails
|
||||||
|
data = {
|
||||||
|
"metadata": {
|
||||||
|
"guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert guardrail.get_guardrail_from_metadata(data) == [
|
||||||
|
{"test-guardrail": {"extra_body": {"key": "value"}}}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_guardrail_is_in_requested_guardrails():
|
||||||
|
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
|
||||||
|
|
||||||
|
# Test with string list
|
||||||
|
assert (
|
||||||
|
guardrail._guardrail_is_in_requested_guardrails(["test-guardrail", "other"])
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
assert guardrail._guardrail_is_in_requested_guardrails(["other"]) == False
|
||||||
|
|
||||||
|
# Test with dict list
|
||||||
|
assert (
|
||||||
|
guardrail._guardrail_is_in_requested_guardrails(
|
||||||
|
[{"test-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
|
||||||
|
)
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
guardrail._guardrail_is_in_requested_guardrails(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"other-guardrail": {"extra_body": {"extra_key": "extra_value"}},
|
||||||
|
"test-guardrail": {"extra_body": {"extra_key": "extra_value"}},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
guardrail._guardrail_is_in_requested_guardrails(
|
||||||
|
[{"other-guardrail": {"extra_body": {"extra_key": "extra_value"}}}]
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_run_guardrail():
|
||||||
|
guardrail = CustomGuardrail(
|
||||||
|
guardrail_name="test-guardrail", event_hook=GuardrailEventHooks.pre_call
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test matching event hook and guardrail
|
||||||
|
assert (
|
||||||
|
guardrail.should_run_guardrail(
|
||||||
|
{"metadata": {"guardrails": ["test-guardrail"]}},
|
||||||
|
GuardrailEventHooks.pre_call,
|
||||||
|
)
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test non-matching event hook
|
||||||
|
assert (
|
||||||
|
guardrail.should_run_guardrail(
|
||||||
|
{"metadata": {"guardrails": ["test-guardrail"]}},
|
||||||
|
GuardrailEventHooks.during_call,
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test guardrail not in requested list
|
||||||
|
assert (
|
||||||
|
guardrail.should_run_guardrail(
|
||||||
|
{"metadata": {"guardrails": ["other-guardrail"]}},
|
||||||
|
GuardrailEventHooks.pre_call,
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_guardrail_dynamic_request_body_params():
|
||||||
|
guardrail = CustomGuardrail(guardrail_name="test-guardrail")
|
||||||
|
|
||||||
|
# Test with no extra_body
|
||||||
|
data = {"metadata": {"guardrails": [{"test-guardrail": {}}]}}
|
||||||
|
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}
|
||||||
|
|
||||||
|
# Test with extra_body
|
||||||
|
data = {
|
||||||
|
"metadata": {
|
||||||
|
"guardrails": [{"test-guardrail": {"extra_body": {"key": "value"}}}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {"key": "value"}
|
||||||
|
|
||||||
|
# Test with non-matching guardrail
|
||||||
|
data = {
|
||||||
|
"metadata": {
|
||||||
|
"guardrails": [{"other-guardrail": {"extra_body": {"key": "value"}}}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert guardrail.get_guardrail_dynamic_request_body_params(data) == {}
|
Loading…
Add table
Add a link
Reference in a new issue