(Feat) - allow setting default_on guardrails (#7973)

* test_default_on_guardrail

* update debug on custom guardrail

* refactor guardrails init

* guardrail registry

* allow switching guardrails default_on

* fix circle import issue

* fix bedrock applying guardrails where content is a list

* fix unused import

* docs default on guardrail

* docs fix per api key
This commit is contained in:
Ishaan Jaff 2025-01-24 10:14:05 -08:00 committed by GitHub
parent 04401c7080
commit d1bc955d97
10 changed files with 292 additions and 325 deletions

View file

@ -1,154 +0,0 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Logging GCS, s3 Buckets
LiteLLM Supports Logging to the following Cloud Buckets
- (Enterprise) ✨ [Google Cloud Storage Buckets](#logging-proxy-inputoutput-to-google-cloud-storage-buckets)
- (Free OSS) [Amazon s3 Buckets](#logging-proxy-inputoutput---s3-buckets)
## Google Cloud Storage Buckets
Log LLM Logs to [Google Cloud Storage Buckets](https://cloud.google.com/storage?hl=en)
:::info
✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
| Property | Details |
|----------|---------|
| Description | Log LLM Input/Output to cloud storage buckets |
| Load Test Benchmarks | [Benchmarks](https://docs.litellm.ai/docs/benchmarks) |
| Google Docs on Cloud Storage | [Google Cloud Storage](https://cloud.google.com/storage?hl=en) |
### Usage
1. Add `gcs_bucket` to LiteLLM Config.yaml
```yaml
model_list:
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model
model_name: fake-openai-endpoint
litellm_settings:
callbacks: ["gcs_bucket"] # 👈 KEY CHANGE # 👈 KEY CHANGE
```
2. Set required env variables
```shell
GCS_BUCKET_NAME="<your-gcs-bucket-name>"
GCS_PATH_SERVICE_ACCOUNT="/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
```
3. Start Proxy
```
litellm --config /path/to/config.yaml
```
4. Test it!
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "fake-openai-endpoint",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
### Expected Logs on GCS Buckets
<Image img={require('../../img/gcs_bucket.png')} />
### Fields Logged on GCS Buckets
[**The standard logging object is logged on GCS Bucket**](../proxy/logging)
### Getting `service_account.json` from Google Cloud Console
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Search for IAM & Admin
3. Click on Service Accounts
4. Select a Service Account
5. Click on 'Keys' -> Add Key -> Create New Key -> JSON
6. Save the JSON file and add the path to `GCS_PATH_SERVICE_ACCOUNT`
## s3 Buckets
We will use the `--config` to set
- `litellm.success_callback = ["s3"]`
This will log all successfull LLM calls to s3 Bucket
**Step 1** Set AWS Credentials in .env
```shell
AWS_ACCESS_KEY_ID = ""
AWS_SECRET_ACCESS_KEY = ""
AWS_REGION_NAME = ""
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["s3"]
s3_callback_params:
s3_bucket_name: logs-bucket-litellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "Azure OpenAI GPT-4 East",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
Your logs should be available on the specified s3 Bucket

View file

@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Quick Start
# Guardrails - Quick Start
Setup Prompt Injection Detection, PII Masking on LiteLLM Proxy (AI Gateway)
@ -121,6 +121,47 @@ curl -i http://localhost:4000/v1/chat/completions \
</Tabs>
## **Default On Guardrails**
Set `default_on: true` in your guardrail config to run the guardrail on every request. This is useful if you want to run a guardrail on every request without the user having to specify it.
```yaml
guardrails:
- guardrail_name: "aporia-pre-guard"
litellm_params:
guardrail: aporia
mode: "pre_call"
default_on: true
```
**Test Request**
In this request, the guardrail `aporia-pre-guard` will run on every request because `default_on: true` is set.
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
]
}'
```
**Expected response**
Your response headers will incude `x-litellm-applied-guardrails` with the guardrail applied
```
x-litellm-applied-guardrails: aporia-pre-guard
```
## **Using Guardrails Client Side**
### Test yourself **(OSS)**
@ -349,7 +390,7 @@ Monitor which guardrails were executed and whether they passed or failed. e.g. g
### ✨ Control Guardrails per Project (API Key)
### ✨ Control Guardrails per API Key
:::info
@ -357,7 +398,7 @@ Monitor which guardrails were executed and whether they passed or failed. e.g. g
:::
Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project (API Key)
Use this to control what guardrails run per API Key. In this tutorial we only want the following guardrails to run for 1 API Key
- `guardrails`: ["aporia-pre-guard", "aporia-post-guard"]
**Step 1** Create Key with guardrail settings
@ -484,6 +525,7 @@ guardrails:
mode: string # Required: One of "pre_call", "post_call", "during_call", "logging_only"
api_key: string # Required: API key for the guardrail service
api_base: string # Optional: Base URL for the guardrail service
default_on: boolean # Optional: Default False. When set to True, will run on every request, does not need client to specify guardrail in request
guardrail_info: # Optional[Dict]: Additional information about the guardrail
```

View file

@ -13,11 +13,22 @@ class CustomGuardrail(CustomLogger):
guardrail_name: Optional[str] = None,
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
event_hook: Optional[GuardrailEventHooks] = None,
default_on: bool = False,
**kwargs,
):
"""
Initialize the CustomGuardrail class
Args:
guardrail_name: The name of the guardrail. This is the name used in your requests.
supported_event_hooks: The event hooks that the guardrail supports
event_hook: The event hook to run the guardrail on
default_on: If True, the guardrail will be run by default on all requests
"""
self.guardrail_name = guardrail_name
self.supported_event_hooks = supported_event_hooks
self.event_hook: Optional[GuardrailEventHooks] = event_hook
self.default_on: bool = default_on
if supported_event_hooks:
## validate event_hook is in supported_event_hooks
@ -51,16 +62,25 @@ class CustomGuardrail(CustomLogger):
return False
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
"""
Returns True if the guardrail should be run on the event_type
"""
requested_guardrails = self.get_guardrail_from_metadata(data)
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
self.guardrail_name,
event_type,
self.event_hook,
requested_guardrails,
self.default_on,
)
if self.default_on is True:
if self._event_hook_is_event_type(event_type):
return True
return False
if (
self.event_hook
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
@ -73,6 +93,15 @@ class CustomGuardrail(CustomLogger):
return True
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
"""
Returns True if the event_hook is the same as the event_type
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
"""
return self.event_hook == event_type.value
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
"""
Returns `extra_body` to be added to the request body for the Guardrail API call

View file

@ -13,7 +13,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import json
import sys
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union
from fastapi import HTTPException
@ -23,6 +23,9 @@ from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
@ -36,6 +39,7 @@ from litellm.types.guardrails import (
BedrockTextContent,
GuardrailEventHooks,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
GUARDRAIL_NAME = "bedrock"
@ -62,7 +66,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def convert_to_bedrock_format(
self,
messages: Optional[List[Dict[str, str]]] = None,
messages: Optional[List[AllMessageValues]] = None,
response: Optional[Union[Any, ModelResponse]] = None,
) -> BedrockRequest:
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
@ -70,10 +74,10 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
if messages:
for message in messages:
content = message.get("content")
if isinstance(content, str):
bedrock_content_item = BedrockContentItem(
text=BedrockTextContent(text=content)
text=BedrockTextContent(
text=convert_content_list_to_str(message=message)
)
)
bedrock_request_content.append(bedrock_content_item)

View file

@ -0,0 +1,114 @@
# litellm/proxy/guardrails/guardrail_initializers.py
import litellm
from litellm.types.guardrails import *
def initialize_aporia(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import AporiaGuardrail
_aporia_callback = AporiaGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_aporia_callback)
def initialize_bedrock(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
BedrockGuardrail,
)
_bedrock_callback = BedrockGuardrail(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
guardrailIdentifier=litellm_params["guardrailIdentifier"],
guardrailVersion=litellm_params["guardrailVersion"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_bedrock_callback)
def initialize_lakera(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import lakeraAI_Moderation
_lakera_callback = lakeraAI_Moderation(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
category_thresholds=litellm_params.get("category_thresholds"),
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_lakera_callback)
def initialize_aim(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.aim import AimGuardrail
_aim_callback = AimGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_aim_callback)
def initialize_presidio(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
_presidio_callback = _OPTIONAL_PresidioPIIMasking(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
output_parse_pii=litellm_params["output_parse_pii"],
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_presidio_callback)
if litellm_params["output_parse_pii"]:
_success_callback = _OPTIONAL_PresidioPIIMasking(
output_parse_pii=True,
guardrail_name=guardrail["guardrail_name"],
event_hook=GuardrailEventHooks.post_call.value,
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_success_callback)
def initialize_hide_secrets(litellm_params, guardrail):
from enterprise.enterprise_hooks.secret_detection import _ENTERPRISE_SecretDetection
_secret_detection_object = _ENTERPRISE_SecretDetection(
detect_secrets_config=litellm_params.get("detect_secrets_config"),
event_hook=litellm_params["mode"],
guardrail_name=guardrail["guardrail_name"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_secret_detection_object)
def initialize_guardrails_ai(litellm_params, guardrail):
from litellm.proxy.guardrails.guardrail_hooks.guardrails_ai import GuardrailsAI
_guard_name = litellm_params.get("guard_name")
if not _guard_name:
raise Exception(
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
)
_guardrails_ai_callback = GuardrailsAI(
api_base=litellm_params.get("api_base"),
guard_name=_guard_name,
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_guardrails_ai_callback)

View file

@ -0,0 +1,23 @@
# litellm/proxy/guardrails/guardrail_registry.py
from litellm.types.guardrails import SupportedGuardrailIntegrations
from .guardrail_initializers import (
initialize_aim,
initialize_aporia,
initialize_bedrock,
initialize_guardrails_ai,
initialize_hide_secrets,
initialize_lakera,
initialize_presidio,
)
guardrail_registry = {
SupportedGuardrailIntegrations.APORIA.value: initialize_aporia,
SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock,
SupportedGuardrailIntegrations.LAKERA.value: initialize_lakera,
SupportedGuardrailIntegrations.AIM.value: initialize_aim,
SupportedGuardrailIntegrations.PRESIDIO.value: initialize_presidio,
SupportedGuardrailIntegrations.HIDE_SECRETS.value: initialize_hide_secrets,
SupportedGuardrailIntegrations.GURDRAILS_AI.value: initialize_guardrails_ai,
}

View file

@ -1,4 +1,5 @@
import importlib
import os
from typing import Dict, List, Optional
import litellm
@ -9,14 +10,14 @@ from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_pr
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailEventHooks,
GuardrailItem,
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
SupportedGuardrailIntegrations,
)
from .guardrail_registry import guardrail_registry
all_guardrails: List[GuardrailItem] = []
@ -83,23 +84,18 @@ Map guardrail_name: <pre_call>, <post_call>, during_call
"""
def init_guardrails_v2( # noqa: PLR0915
def init_guardrails_v2(
all_guardrails: List[Dict],
config_file_path: Optional[str] = None,
):
# Convert the loaded data to the TypedDict structure
guardrail_list = []
# Parse each guardrail and replace environment variables
for guardrail in all_guardrails:
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
_litellm_params_kwargs = {
k: litellm_params_data[k] if k in litellm_params_data else None
for k in LitellmParams.__annotations__.keys()
k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
}
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
@ -113,157 +109,41 @@ def init_guardrails_v2( # noqa: PLR0915
)
litellm_params["category_thresholds"] = lakera_category_thresholds
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
if litellm_params["api_key"] and litellm_params["api_key"].startswith(
"os.environ/"
):
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
if litellm_params["api_base"] and litellm_params["api_base"].startswith(
"os.environ/"
):
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == SupportedGuardrailIntegrations.APORIA.value:
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
AporiaGuardrail,
)
guardrail_type = litellm_params["guardrail"]
_aporia_callback = AporiaGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
elif (
litellm_params["guardrail"] == SupportedGuardrailIntegrations.BEDROCK.value
):
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
BedrockGuardrail,
)
initializer = guardrail_registry.get(guardrail_type)
_bedrock_callback = BedrockGuardrail(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
guardrailIdentifier=litellm_params["guardrailIdentifier"],
guardrailVersion=litellm_params["guardrailVersion"],
)
litellm.callbacks.append(_bedrock_callback) # type: ignore
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.LAKERA.value:
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,
)
_lakera_callback = lakeraAI_Moderation(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
category_thresholds=litellm_params.get("category_thresholds"),
)
litellm.callbacks.append(_lakera_callback) # type: ignore
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.AIM.value:
from litellm.proxy.guardrails.guardrail_hooks.aim import (
AimGuardrail,
)
_aim_callback = AimGuardrail(
api_base=litellm_params["api_base"],
api_key=litellm_params["api_key"],
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aim_callback) # type: ignore
elif (
litellm_params["guardrail"] == SupportedGuardrailIntegrations.PRESIDIO.value
):
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
_presidio_callback = _OPTIONAL_PresidioPIIMasking(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
output_parse_pii=litellm_params["output_parse_pii"],
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
)
if litellm_params["output_parse_pii"] is True:
_success_callback = _OPTIONAL_PresidioPIIMasking(
output_parse_pii=True,
guardrail_name=guardrail["guardrail_name"],
event_hook=GuardrailEventHooks.post_call.value,
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
)
litellm.callbacks.append(_success_callback) # type: ignore
litellm.callbacks.append(_presidio_callback) # type: ignore
elif (
litellm_params["guardrail"]
== SupportedGuardrailIntegrations.HIDE_SECRETS.value
):
from enterprise.enterprise_hooks.secret_detection import (
_ENTERPRISE_SecretDetection,
)
_secret_detection_object = _ENTERPRISE_SecretDetection(
detect_secrets_config=litellm_params.get("detect_secrets_config"),
event_hook=litellm_params["mode"],
guardrail_name=guardrail["guardrail_name"],
)
litellm.callbacks.append(_secret_detection_object) # type: ignore
elif (
litellm_params["guardrail"]
== SupportedGuardrailIntegrations.GURDRAILS_AI.value
):
from litellm.proxy.guardrails.guardrail_hooks.guardrails_ai import (
GuardrailsAI,
)
_guard_name = litellm_params.get("guard_name")
if _guard_name is None:
raise Exception(
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
)
_guardrails_ai_callback = GuardrailsAI(
api_base=litellm_params.get("api_base"),
guard_name=_guard_name,
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
)
litellm.callbacks.append(_guardrails_ai_callback) # type: ignore
elif (
isinstance(litellm_params["guardrail"], str)
and "." in litellm_params["guardrail"]
):
if config_file_path is None:
if initializer:
initializer(litellm_params, guardrail)
elif isinstance(guardrail_type, str) and "." in guardrail_type:
if not config_file_path:
raise Exception(
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
)
import os
# Custom guardrail
_guardrail = litellm_params["guardrail"]
_file_name, _class_name = _guardrail.split(".")
_file_name, _class_name = guardrail_type.split(".")
verbose_proxy_logger.debug(
"Initializing custom guardrail: %s, file_name: %s, class_name: %s",
_guardrail,
guardrail_type,
_file_name,
_class_name,
)
directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, _file_name)
module_file_path += ".py"
module_file_path = os.path.join(directory, _file_name) + ".py"
spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore
if spec is None:
if not spec:
raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
@ -275,10 +155,11 @@ def init_guardrails_v2( # noqa: PLR0915
_guardrail_callback = _guardrail_class(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_guardrail_callback) # type: ignore
else:
raise ValueError(f"Unsupported guardrail: {litellm_params['guardrail']}")
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
parsed_guardrail = Guardrail(
guardrail_name=guardrail["guardrail_name"],
@ -286,6 +167,5 @@ def init_guardrails_v2( # noqa: PLR0915
)
guardrail_list.append(parsed_guardrail)
guardrail["guardrail_name"]
# pretty print guardrail_list in green
print(f"\nGuardrail List:{guardrail_list}\n") # noqa

View file

@ -23,11 +23,6 @@ guardrails:
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"
- guardrail_name: "bedrock-post-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "post_call"
guardrailIdentifier: ff6ujrregl1q
guardrailIdentifier: gf3sc1mzinjw
guardrailVersion: "DRAFT"
default_on: true

View file

@ -7,14 +7,13 @@ from typing_extensions import Required, TypedDict
"""
Pydantic object defining how to set guardrails on litellm proxy
litellm_settings:
guardrails:
- prompt_injection:
callbacks: [lakera_prompt_injection, prompt_injection_api_2]
default_on: true
enabled_roles: [system, user]
- detect_secrets:
callbacks: [hide_secrets]
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"
default_on: true
"""
@ -104,6 +103,7 @@ class LitellmParams(TypedDict):
# guardrails ai params
guard_name: Optional[str]
default_on: Optional[bool]
class Guardrail(TypedDict, total=False):

View file

@ -194,3 +194,37 @@ def test_get_guardrails_list_response():
assert len(minimal_response.guardrails) == 1
assert minimal_response.guardrails[0].guardrail_name == "minimal-guard"
assert minimal_response.guardrails[0].guardrail_info is None
def test_default_on_guardrail():
# Test guardrail with default_on=True
guardrail = CustomGuardrail(
guardrail_name="test-guardrail",
event_hook=GuardrailEventHooks.pre_call,
default_on=True,
)
# Should run when event_type matches, even without explicit request
assert (
guardrail.should_run_guardrail(
{"metadata": {}}, # Empty metadata, no explicit guardrail request
GuardrailEventHooks.pre_call,
)
== True
)
# Should not run when event_type doesn't match
assert (
guardrail.should_run_guardrail({"metadata": {}}, GuardrailEventHooks.post_call)
== False
)
# Should run even when different guardrail explicitly requested
# run test-guardrail-5 and test-guardrail
assert (
guardrail.should_run_guardrail(
{"metadata": {"guardrails": ["test-guardrail-5"]}},
GuardrailEventHooks.pre_call,
)
== True
)