(Feat) - allow setting default_on guardrails (#7973)

* test_default_on_guardrail

* update debug on custom guardrail

* refactor guardrails init

* guardrail registry

* allow switching guardrails default_on

* fix circle import issue

* fix bedrock applying guardrails where content is a list

* fix unused import

* docs default on guardrail

* docs fix per api key
This commit is contained in:
Ishaan Jaff 2025-01-24 10:14:05 -08:00 committed by GitHub
parent 04401c7080
commit d1bc955d97
10 changed files with 292 additions and 325 deletions

View file

@ -13,11 +13,22 @@ class CustomGuardrail(CustomLogger):
guardrail_name: Optional[str] = None,
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
event_hook: Optional[GuardrailEventHooks] = None,
default_on: bool = False,
**kwargs,
):
"""
Initialize the CustomGuardrail class
Args:
guardrail_name: The name of the guardrail. This is the name used in your requests.
supported_event_hooks: The event hooks that the guardrail supports
event_hook: The event hook to run the guardrail on
default_on: If True, the guardrail will be run by default on all requests
"""
self.guardrail_name = guardrail_name
self.supported_event_hooks = supported_event_hooks
self.event_hook: Optional[GuardrailEventHooks] = event_hook
self.default_on: bool = default_on
if supported_event_hooks:
## validate event_hook is in supported_event_hooks
@ -51,16 +62,25 @@ class CustomGuardrail(CustomLogger):
return False
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
"""
Returns True if the guardrail should be run on the event_type
"""
requested_guardrails = self.get_guardrail_from_metadata(data)
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
self.guardrail_name,
event_type,
self.event_hook,
requested_guardrails,
self.default_on,
)
if self.default_on is True:
if self._event_hook_is_event_type(event_type):
return True
return False
if (
self.event_hook
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
@ -73,6 +93,15 @@ class CustomGuardrail(CustomLogger):
return True
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
"""
Returns True if the event_hook is the same as the event_type
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
"""
return self.event_hook == event_type.value
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
"""
Returns `extra_body` to be added to the request body for the Guardrail API call