From 49847347d05e7f51c6e47ea98baabdde5e17b9ef Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 20 Feb 2024 20:31:32 -0800 Subject: [PATCH] fix(llm_guard.py): add streaming hook for moderation calls --- enterprise/enterprise_hooks/llm_guard.py | 19 ++++++++----------- litellm/integrations/custom_logger.py | 7 +++++++ litellm/proxy/utils.py | 21 +++++++++++++++++++++ litellm/utils.py | 14 -------------- 4 files changed, 36 insertions(+), 25 deletions(-) diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index c000f6011..58eb71ee3 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -101,19 +101,16 @@ class _ENTERPRISE_LLMGuard(CustomLogger): - Use the sanitized prompt returned - LLM Guard can handle things like PII Masking, etc. """ - if "messages" in data: - safety_check_messages = data["messages"][ - -1 - ] # get the last response - llama guard has a 4k token limit - if ( - isinstance(safety_check_messages, dict) - and "content" in safety_check_messages - and isinstance(safety_check_messages["content"], str) - ): - await self.moderation_check(safety_check_messages["content"]) - return data + async def async_post_call_streaming_hook( + self, user_api_key_dict: UserAPIKeyAuth, response: str + ): + if response is not None: + await self.moderation_check(text=response) + + return response + # llm_guard = _ENTERPRISE_LLMGuard() diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index c29c964fc..40242f5c0 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -75,6 +75,13 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_moderation_hook(self, data: dict): pass + async def async_post_call_streaming_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + response: str, + ): + pass + #### SINGLE-USE #### - 
https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 78c1e4b63..3cad1777c 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -401,6 +401,27 @@ class ProxyLogging: raise e return new_response + async def post_call_streaming_hook( + self, + response: str, + user_api_key_dict: UserAPIKeyAuth, + ): + """ + - Check outgoing streaming response up to that point + - Run through moderation check + - Reject request if it fails moderation check + """ + new_response = copy.deepcopy(response) + for callback in litellm.callbacks: + try: + if isinstance(callback, CustomLogger): + await callback.async_post_call_streaming_hook( + user_api_key_dict=user_api_key_dict, response=new_response + ) + except Exception as e: + raise e + return new_response + ### DB CONNECTOR ### # Define the retry decorator with backoff strategy diff --git a/litellm/utils.py b/litellm/utils.py index 982462e3f..2b3764b1e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -909,20 +909,6 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - if litellm.max_budget and self.stream: - start_time = self.start_time - end_time = ( - self.start_time - ) # no time has passed as the call hasn't been made yet - time_diff = (end_time - start_time).total_seconds() - float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost( - model=self.model, - prompt="".join(message["content"] for message in self.messages), - completion="", - total_time=float_diff, - ) - # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made callbacks = litellm.input_callback + self.dynamic_input_callbacks for callback in callbacks: