mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(Feat) - allow setting default_on
guardrails (#7973)
* test_default_on_guardrail * update debug on custom guardrail * refactor guardrails init * guardrail registry * allow switching guardrails default_on * fix circle import issue * fix bedrock applying guardrails where content is a list * fix unused import * docs default on guardrail * docs fix per api key
This commit is contained in:
parent
04401c7080
commit
d1bc955d97
10 changed files with 292 additions and 325 deletions
|
@ -13,11 +13,22 @@ class CustomGuardrail(CustomLogger):
|
|||
guardrail_name: Optional[str] = None,
|
||||
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
|
||||
event_hook: Optional[GuardrailEventHooks] = None,
|
||||
default_on: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the CustomGuardrail class
|
||||
|
||||
Args:
|
||||
guardrail_name: The name of the guardrail. This is the name used in your requests.
|
||||
supported_event_hooks: The event hooks that the guardrail supports
|
||||
event_hook: The event hook to run the guardrail on
|
||||
default_on: If True, the guardrail will be run by default on all requests
|
||||
"""
|
||||
self.guardrail_name = guardrail_name
|
||||
self.supported_event_hooks = supported_event_hooks
|
||||
self.event_hook: Optional[GuardrailEventHooks] = event_hook
|
||||
self.default_on: bool = default_on
|
||||
|
||||
if supported_event_hooks:
|
||||
## validate event_hook is in supported_event_hooks
|
||||
|
@ -51,16 +62,25 @@ class CustomGuardrail(CustomLogger):
|
|||
return False
|
||||
|
||||
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||
"""
|
||||
Returns True if the guardrail should be run on the event_type
|
||||
"""
|
||||
requested_guardrails = self.get_guardrail_from_metadata(data)
|
||||
|
||||
verbose_logger.debug(
|
||||
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
|
||||
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
|
||||
self.guardrail_name,
|
||||
event_type,
|
||||
self.event_hook,
|
||||
requested_guardrails,
|
||||
self.default_on,
|
||||
)
|
||||
|
||||
if self.default_on is True:
|
||||
if self._event_hook_is_event_type(event_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
if (
|
||||
self.event_hook
|
||||
and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
|
||||
|
@ -73,6 +93,15 @@ class CustomGuardrail(CustomLogger):
|
|||
|
||||
return True
|
||||
|
||||
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
|
||||
"""
|
||||
Returns True if the event_hook is the same as the event_type
|
||||
|
||||
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
|
||||
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
|
||||
"""
|
||||
return self.event_hook == event_type.value
|
||||
|
||||
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
|
||||
"""
|
||||
Returns `extra_body` to be added to the request body for the Guardrail API call
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue