fix use guardrail for pre call hook

Ishaan Jaff 2024-08-23 09:34:08 -07:00
parent 6e3f27cf69
commit a8e192a868
4 changed files with 30 additions and 54 deletions

View file

@@ -18,16 +18,16 @@ class CustomGuardrail(CustomLogger):
         super().__init__(**kwargs)
 
     def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
+        metadata = data.get("metadata") or {}
+        requested_guardrails = metadata.get("guardrails") or []
         verbose_logger.debug(
-            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
+            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
             self.guardrail_name,
             event_type,
             self.event_hook,
+            requested_guardrails,
         )
-        metadata = data.get("metadata") or {}
-        requested_guardrails = metadata.get("guardrails") or []
         if self.guardrail_name not in requested_guardrails:
             return False
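
The reordering above also surfaces the requested guardrails in the debug log before the membership check. A minimal sketch of how that check is driven by request metadata follows; the import path and constructor kwargs are assumptions, only should_run_guardrail, guardrail_name, and event_hook come from the diff.

# A minimal sketch, assuming the CustomGuardrail import path and the
# constructor kwargs below; only should_run_guardrail, guardrail_name and
# event_hook are taken from the diff.
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

guardrail = CustomGuardrail(
    guardrail_name="custom-pre-guard",  # hypothetical name
    event_hook=GuardrailEventHooks.pre_call,
)

data = {
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {"guardrails": ["custom-pre-guard"]},  # request opts into the guardrail
}

# False when the request's metadata.guardrails list does not name this guardrail.
print(guardrail.should_run_guardrail(data=data, event_type=GuardrailEventHooks.pre_call))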

View file

@@ -1,19 +1,5 @@
-import os
-import sys
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import asyncio
-import json
-import sys
-import traceback
-import uuid
-from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Union
-from fastapi import HTTPException
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
@@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail):
             "pass_through_endpoint",
         ],
     ) -> Optional[Union[Exception, str, dict]]:
-        # In this guardrail, if a user inputs `litellm` we will mask it.
+        """
+        Runs before the LLM API call
+        Runs on only Input
+        Use this if you want to MODIFY the input
+        """
+
+        # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
         _messages = data.get("messages")
         if _messages:
             for message in _messages:
@@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail):
         """
         Runs in parallel to LLM API call
         Runs on only Input
+
+        This can NOT modify the input, only used to reject or accept a call before going to LLM API
         """
 
         # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
@@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail):
                 _content = message.get("content")
                 if isinstance(_content, str):
                     if "litellm" in _content.lower():
-                        _content = _content.replace("litellm", "********")
-                        message["content"] = _content
-
-        verbose_proxy_logger.debug(
-            "async_pre_call_hook: Message after masking %s", _messages
-        )
-        pass
+                        raise ValueError("Guardrail failed words - `litellm` detected")
 
     async def async_post_call_success_hook(
         self,
@@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail):
         """
         Runs on response from LLM API call
+
+        It can be used to reject a response
         If a response contains the word "coffee" -> we will raise an exception
         """
 
         verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
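
A hedged usage sketch of the behaviour this file now documents: the parallel moderation hook rejects the call instead of masking the input. The instantiation and hook arguments below are illustrative assumptions; only async_moderation_hook, the messages shape, and the raised ValueError come from the diff.

# Illustrative sketch only -- argument construction is assumed, not from the commit.
import asyncio

from litellm.proxy._types import UserAPIKeyAuth  # assumed import path


async def demo() -> None:
    guardrail = myCustomGuardrail()
    data = {"messages": [{"role": "user", "content": "what is litellm?"}]}
    try:
        # Runs in parallel with the LLM API call; after this change it raises
        # instead of masking the flagged word.
        await guardrail.async_moderation_hook(
            data=data,
            user_api_key_dict=UserAPIKeyAuth(),
            call_type="completion",
        )
    except ValueError as err:
        print(err)  # Guardrail failed words - `litellm` detected


asyncio.run(demo())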

View file

@@ -1,19 +1,5 @@
-import os
-import sys
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import asyncio
-import json
-import sys
-import traceback
-import uuid
-from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Union
-from fastapi import HTTPException
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
@@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail):
             "pass_through_endpoint",
         ],
     ) -> Optional[Union[Exception, str, dict]]:
-        # In this guardrail, if a user inputs `litellm` we will mask it.
+        """
+        Runs before the LLM API call
+        Runs on only Input
+        Use this if you want to MODIFY the input
+        """
+
+        # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
         _messages = data.get("messages")
         if _messages:
             for message in _messages:
@@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail):
         """
         Runs in parallel to LLM API call
         Runs on only Input
+
+        This can NOT modify the input, only used to reject or accept a call before going to LLM API
         """
 
         # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
@@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail):
                 _content = message.get("content")
                 if isinstance(_content, str):
                     if "litellm" in _content.lower():
-                        _content = _content.replace("litellm", "********")
-                        message["content"] = _content
-
-        verbose_proxy_logger.debug(
-            "async_pre_call_hook: Message after masking %s", _messages
-        )
-        pass
+                        raise ValueError("Guardrail failed words - `litellm` detected")
 
     async def async_post_call_success_hook(
         self,
@@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail):
         """
         Runs on response from LLM API call
+
+        It can be used to reject a response
         If a response contains the word "coffee" -> we will raise an exception
         """
 
         verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
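
The docstring added here notes the success hook can reject a response. A hedged sketch follows, under an assumed response shape and assumed keyword arguments; only async_post_call_success_hook and the "coffee" rule come from the docstring above.

# Hedged sketch; the ModelResponse construction and keyword arguments are
# assumptions, only async_post_call_success_hook and the "coffee" rule come
# from the docstring above.
import asyncio

import litellm
from litellm.proxy._types import UserAPIKeyAuth  # assumed import path


async def demo() -> None:
    guardrail = myCustomGuardrail()
    # Assumed response shape: one assistant choice that mentions "coffee".
    response = litellm.ModelResponse(
        choices=[{"message": {"role": "assistant", "content": "take a coffee break"}}]
    )
    try:
        # Runs on the LLM response; per the docstring a response mentioning
        # "coffee" is rejected with an exception.
        await guardrail.async_post_call_success_hook(
            data={},
            user_api_key_dict=UserAPIKeyAuth(),
            response=response,
        )
    except Exception as err:
        print(f"response rejected: {err}")


asyncio.run(demo())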

View file

@@ -393,7 +393,7 @@ class ProxyLogging:
         try:
             for callback in litellm.callbacks:
-                _callback: Optional[CustomLogger] = None
+                _callback = None
                 if isinstance(callback, str):
                     _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                         callback
@@ -401,11 +401,7 @@
                 else:
                     _callback = callback  # type: ignore
-                if (
-                    _callback is not None
-                    and isinstance(_callback, CustomGuardrail)
-                    and "pre_call_hook" in vars(_callback.__class__)
-                ):
+                if _callback is not None and isinstance(_callback, CustomGuardrail):
                     from litellm.types.guardrails import GuardrailEventHooks
 
                     if (
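
Read together with the first file, the dispatch no longer requires a guardrail class to define pre_call_hook itself; it defers to should_run_guardrail. A hedged restatement of that rule as a standalone helper (the function name and wrapper are hypothetical, not proxy code):

# Hypothetical helper restating the new dispatch rule; not a verbatim copy of
# the proxy code, and the import path for CustomGuardrail is assumed.
from typing import Any

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks


def runs_as_pre_call_guardrail(callback: Any, data: dict) -> bool:
    # The old code also required "pre_call_hook" in vars(callback.__class__);
    # that check is gone, so any CustomGuardrail is considered and
    # should_run_guardrail decides based on the request's metadata.guardrails.
    if callback is None or not isinstance(callback, CustomGuardrail):
        return False
    return callback.should_run_guardrail(
        data=data, event_type=GuardrailEventHooks.pre_call
    )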