add custom guardrail reference

This commit is contained in:
Ishaan Jaff 2024-08-23 08:32:07 -07:00
parent e62d0c7922
commit af92cff44d
4 changed files with 342 additions and 39 deletions

View file

@ -30,6 +30,7 @@ from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.slack_alerting import SlackAlerting
from litellm.litellm_core_utils.core_helpers import (
@ -344,6 +345,23 @@ class ProxyLogging:
ttl=alerting_threshold,
)
async def process_pre_call_hook_response(self, response, data, call_type):
if isinstance(response, Exception):
raise response
if isinstance(response, dict):
return response
if isinstance(response, str):
if call_type in ["completion", "text_completion"]:
raise RejectedRequestError(
message=response,
model=data.get("model", ""),
llm_provider="",
request_data=data,
)
else:
raise HTTPException(status_code=400, detail={"error": response})
return data
# The actual implementation of the function
async def pre_call_hook(
self,
@ -382,7 +400,33 @@ class ProxyLogging:
)
else:
_callback = callback # type: ignore
if (
_callback is not None
and isinstance(_callback, CustomGuardrail)
and "pre_call_hook" in vars(_callback.__class__)
):
from litellm.types.guardrails import GuardrailEventHooks
if (
_callback.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.pre_call
)
is not True
):
continue
response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None:
data = await self.process_pre_call_hook_response(
response=response, data=data, call_type=call_type
)
elif (
_callback is not None
and isinstance(_callback, CustomLogger)
and "async_pre_call_hook" in vars(_callback.__class__)
@ -394,25 +438,9 @@ class ProxyLogging:
call_type=call_type,
)
if response is not None:
if isinstance(response, Exception):
raise response
elif isinstance(response, dict):
data = response
elif isinstance(response, str):
if (
call_type == "completion"
or call_type == "text_completion"
):
raise RejectedRequestError(
message=response,
model=data.get("model", ""),
llm_provider="",
request_data=data,
)
else:
raise HTTPException(
status_code=400, detail={"error": response}
)
data = await self.process_pre_call_hook_response(
response=response, data=data, call_type=call_type
)
return data
except Exception as e:
@ -431,11 +459,30 @@ class ProxyLogging:
],
):
"""
Runs the CustomLogger's async_moderation_hook()
Runs the CustomGuardrail's async_moderation_hook()
"""
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
if isinstance(callback, CustomGuardrail):
################################################################
# Check if guardrail should be run for GuardrailEventHooks.during_call hook
################################################################
# V1 implementation - backwards compatibility
if callback.event_hook is None:
if callback.moderation_check == "pre_call":
return
else:
# Main - V2 Guardrails implementation
from litellm.types.guardrails import GuardrailEventHooks
if (
callback.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.during_call
)
is not True
):
continue
await callback.async_moderation_hook(
data=data,
user_api_key_dict=user_api_key_dict,
@ -737,12 +784,36 @@ class ProxyLogging:
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
if _callback is not None:
############## Handle Guardrails ########################################
#############################################################################
if isinstance(callback, CustomGuardrail):
# Main - V2 Guardrails implementation
from litellm.types.guardrails import GuardrailEventHooks
if (
callback.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.post_call
)
is not True
):
continue
await callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
############ Handle CustomLogger ###############################
#################################################################
elif isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
except Exception as e:
raise e
return response