test(base_llm_unit_tests.py): add test to ensure drop params is respe… (#8224)

* test(base_llm_unit_tests.py): add test to ensure drop params is respected

* fix(types/prometheus.py): use typing_extensions for python3.8 compatibility

* build: add cherry picked commits
This commit is contained in:
Krish Dholakia 2025-02-03 16:04:44 -08:00 committed by GitHub
parent d60d3ee970
commit c8494abdea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 250 additions and 71 deletions

View file

@ -12,7 +12,9 @@ class CustomGuardrail(CustomLogger):
self,
guardrail_name: Optional[str] = None,
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
event_hook: Optional[GuardrailEventHooks] = None,
event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
] = None,
default_on: bool = False,
**kwargs,
):
@ -27,16 +29,34 @@ class CustomGuardrail(CustomLogger):
"""
self.guardrail_name = guardrail_name
self.supported_event_hooks = supported_event_hooks
self.event_hook: Optional[GuardrailEventHooks] = event_hook
self.event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
] = event_hook
self.default_on: bool = default_on
if supported_event_hooks:
## validate event_hook is in supported_event_hooks
if event_hook and event_hook not in supported_event_hooks:
self._validate_event_hook(event_hook, supported_event_hooks)
super().__init__(**kwargs)
def _validate_event_hook(
self,
event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
supported_event_hooks: List[GuardrailEventHooks],
) -> None:
if event_hook is None:
return
if isinstance(event_hook, list):
for hook in event_hook:
if hook not in supported_event_hooks:
raise ValueError(
f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
)
elif isinstance(event_hook, GuardrailEventHooks):
if event_hook not in supported_event_hooks:
raise ValueError(
f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
)
super().__init__(**kwargs)
def get_guardrail_from_metadata(
self, data: dict
@ -88,7 +108,7 @@ class CustomGuardrail(CustomLogger):
):
return False
if self.event_hook and self.event_hook != event_type.value:
if not self._event_hook_is_event_type(event_type):
return False
return True
@ -100,6 +120,11 @@ class CustomGuardrail(CustomLogger):
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
"""
if self.event_hook is None:
return True
if isinstance(self.event_hook, list):
return event_type.value in self.event_hook
return self.event_hook == event_type.value
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict: