mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00

add custom guardrail reference

This commit is contained in:
parent e62d0c7922
commit af92cff44d

4 changed files with 342 additions and 39 deletions
@@ -30,6 +30,7 @@ from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.slack_alerting import SlackAlerting
from litellm.litellm_core_utils.core_helpers import (
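For context: CustomGuardrail, imported above, is the base class that user-defined guardrails subclass. A minimal sketch of such a subclass, using only the async_pre_call_hook signature the hunks below call; the class name and the "secret" check are hypothetical, not part of this commit:

from typing import Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth


class ExampleGuardrail(CustomGuardrail):
    """Hypothetical guardrail: rejects requests whose last message mentions 'secret'."""

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        messages = data.get("messages") or []
        last = messages[-1] if messages else {}
        content = str(last.get("content", "")) if isinstance(last, dict) else str(last)
        if "secret" in content:
            # a str return asks the proxy to reject the request
            # (handled by process_pre_call_hook_response, added below)
            return "Request blocked by ExampleGuardrail"
        # returning the (possibly modified) dict lets the request continue
        return data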
@@ -344,6 +345,23 @@ class ProxyLogging:
            ttl=alerting_threshold,
        )

    async def process_pre_call_hook_response(self, response, data, call_type):
        if isinstance(response, Exception):
            raise response
        if isinstance(response, dict):
            return response
        if isinstance(response, str):
            if call_type in ["completion", "text_completion"]:
                raise RejectedRequestError(
                    message=response,
                    model=data.get("model", ""),
                    llm_provider="",
                    request_data=data,
                )
            else:
                raise HTTPException(status_code=400, detail={"error": response})
        return data

    # The actual implementation of the function
    async def pre_call_hook(
        self,
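The new helper gives every pre-call hook one place where its return value is interpreted: an Exception is re-raised, a dict replaces the request body, a str rejects the request (RejectedRequestError for completion/text_completion, HTTP 400 for everything else), and anything else leaves the data untouched. A small illustrative driver of that contract, assuming ProxyLogging can be constructed with just a DualCache the way the proxy server builds it:

import asyncio

from fastapi import HTTPException

from litellm.caching import DualCache
from litellm.exceptions import RejectedRequestError
from litellm.proxy.utils import ProxyLogging


async def main():
    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
    data = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}

    # dict -> becomes the request body used for the rest of the call
    data = await proxy_logging.process_pre_call_hook_response(
        response={**data, "metadata": {"guardrail": "ran"}},
        data=data,
        call_type="completion",
    )

    # str -> RejectedRequestError for completion / text_completion ...
    try:
        await proxy_logging.process_pre_call_hook_response(
            response="Blocked by guardrail", data=data, call_type="completion"
        )
    except RejectedRequestError as e:
        print("rejected:", e)

    # ... and HTTP 400 for any other call type
    try:
        await proxy_logging.process_pre_call_hook_response(
            response="Blocked by guardrail", data=data, call_type="embeddings"
        )
    except HTTPException as e:
        print("400:", e.detail)


asyncio.run(main())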
@@ -382,7 +400,33 @@ class ProxyLogging:
                    )
                else:
                    _callback = callback  # type: ignore

                if (
                    _callback is not None
                    and isinstance(_callback, CustomGuardrail)
                    and "pre_call_hook" in vars(_callback.__class__)
                ):
                    from litellm.types.guardrails import GuardrailEventHooks

                    if (
                        _callback.should_run_guardrail(
                            data=data, event_type=GuardrailEventHooks.pre_call
                        )
                        is not True
                    ):
                        continue
                    response = await _callback.async_pre_call_hook(
                        user_api_key_dict=user_api_key_dict,
                        cache=self.call_details["user_api_key_cache"],
                        data=data,
                        call_type=call_type,
                    )
                    if response is not None:
                        data = await self.process_pre_call_hook_response(
                            response=response, data=data, call_type=call_type
                        )

                elif (
                    _callback is not None
                    and isinstance(_callback, CustomLogger)
                    and "async_pre_call_hook" in vars(_callback.__class__)
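In the new branch above, a CustomGuardrail only reaches async_pre_call_hook when should_run_guardrail(...) returns True for GuardrailEventHooks.pre_call; otherwise the callback loop skips it. A rough sketch of that gate; the guardrail_name / event_hook constructor arguments and the metadata.guardrails request field are assumptions about the v2 guardrails wiring, not shown in this diff:

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

# assumed: the guardrail is registered with a name and the stage it runs at
guardrail = CustomGuardrail(
    guardrail_name="my-guardrail",
    event_hook=GuardrailEventHooks.pre_call,
)

request_data = {"metadata": {"guardrails": ["my-guardrail"]}}

# expected True: the request opted into this guardrail and the stage matches
print(guardrail.should_run_guardrail(data=request_data, event_type=GuardrailEventHooks.pre_call))

# expected False: wrong stage, so pre_call_hook() would `continue` past it
print(guardrail.should_run_guardrail(data=request_data, event_type=GuardrailEventHooks.during_call))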
@@ -394,25 +438,9 @@ class ProxyLogging:
                        call_type=call_type,
                    )
                    if response is not None:
                        if isinstance(response, Exception):
                            raise response
                        elif isinstance(response, dict):
                            data = response
                        elif isinstance(response, str):
                            if (
                                call_type == "completion"
                                or call_type == "text_completion"
                            ):
                                raise RejectedRequestError(
                                    message=response,
                                    model=data.get("model", ""),
                                    llm_provider="",
                                    request_data=data,
                                )
                            else:
                                raise HTTPException(
                                    status_code=400, detail={"error": response}
                                )
                        data = await self.process_pre_call_hook_response(
                            response=response, data=data, call_type=call_type
                        )

            return data
        except Exception as e:
@@ -431,11 +459,30 @@ class ProxyLogging:
            ],
        ):
            """
            Runs the CustomLogger's async_moderation_hook()
            Runs the CustomGuardrail's async_moderation_hook()
            """
            for callback in litellm.callbacks:
                try:
                    if isinstance(callback, CustomLogger):
                    if isinstance(callback, CustomGuardrail):
                        ################################################################
                        # Check if guardrail should be run for GuardrailEventHooks.during_call hook
                        ################################################################

                        # V1 implementation - backwards compatibility
                        if callback.event_hook is None:
                            if callback.moderation_check == "pre_call":
                                return
                        else:
                            # Main - V2 Guardrails implementation
                            from litellm.types.guardrails import GuardrailEventHooks

                            if (
                                callback.should_run_guardrail(
                                    data=data, event_type=GuardrailEventHooks.during_call
                                )
                                is not True
                            ):
                                continue
                        await callback.async_moderation_hook(
                            data=data,
                            user_api_key_dict=user_api_key_dict,
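The moderation path now distinguishes V1 guardrails (no event_hook; still gated by moderation_check) from V2 guardrails, which are gated by should_run_guardrail(..., GuardrailEventHooks.during_call) before async_moderation_hook runs alongside the LLM call. A sketch of a guardrail that plugs into this path; the call site above passes data and user_api_key_dict, call_type is assumed from the CustomLogger base signature, and the banned-phrase check is a placeholder:

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth


class ExampleModerationGuardrail(CustomGuardrail):
    """Hypothetical during_call guardrail: raising here is how it fails the request."""

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: str,
    ):
        text = " ".join(
            str(m.get("content", ""))
            for m in (data.get("messages") or [])
            if isinstance(m, dict)
        )
        if "banned phrase" in text:
            # the raised exception propagates out of the moderation hook and
            # aborts the request
            raise ValueError("Violated content policy (example check)")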
@@ -737,12 +784,36 @@ class ProxyLogging:
                    )
                else:
                    _callback = callback  # type: ignore
                if _callback is not None and isinstance(_callback, CustomLogger):
                    await _callback.async_post_call_success_hook(
                        user_api_key_dict=user_api_key_dict,
                        data=data,
                        response=response,
                    )

                if _callback is not None:
                    ############## Handle Guardrails ########################################
                    #############################################################################
                    if isinstance(callback, CustomGuardrail):
                        # Main - V2 Guardrails implementation
                        from litellm.types.guardrails import GuardrailEventHooks

                        if (
                            callback.should_run_guardrail(
                                data=data, event_type=GuardrailEventHooks.post_call
                            )
                            is not True
                        ):
                            continue

                        await callback.async_post_call_success_hook(
                            user_api_key_dict=user_api_key_dict,
                            data=data,
                            response=response,
                        )

                    ############ Handle CustomLogger ###############################
                    #################################################################
                    elif isinstance(_callback, CustomLogger):
                        await _callback.async_post_call_success_hook(
                            user_api_key_dict=user_api_key_dict,
                            data=data,
                            response=response,
                        )
            except Exception as e:
                raise e
        return response
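post_call_success_hook gets the same treatment: CustomGuardrail callbacks are gated by GuardrailEventHooks.post_call and then receive the finished response via async_post_call_success_hook, so a guardrail can inspect or redact model output before it is returned. A sketch using the keyword arguments from the call above; the response shape and the redaction rule are illustrative assumptions:

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth


class ExampleResponseGuardrail(CustomGuardrail):
    """Hypothetical post_call guardrail: masks a pattern in chat responses."""

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        # assumes a chat-completion style response exposing .choices[*].message.content
        for choice in getattr(response, "choices", None) or []:
            message = getattr(choice, "message", None)
            if message is not None and getattr(message, "content", None):
                message.content = message.content.replace("example-secret", "[REDACTED]")
        return response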