From 6eb8fe35c8e0642c3efd9427b596d113d5697de1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 18 Apr 2024 07:32:29 -0700
Subject: [PATCH] fix(utils.py): function_setup empty message fix

fixes https://github.com/BerriAI/litellm/issues/2858
---
 litellm/tests/test_function_setup.py |  33 +++
 litellm/utils.py                     | 400 +++++++++++++--------------
 2 files changed, 231 insertions(+), 202 deletions(-)
 create mode 100644 litellm/tests/test_function_setup.py

diff --git a/litellm/tests/test_function_setup.py b/litellm/tests/test_function_setup.py
new file mode 100644
index 0000000000..4be36bacca
--- /dev/null
+++ b/litellm/tests/test_function_setup.py
@@ -0,0 +1,33 @@
+# What is this?
+## Unit tests for the 'function_setup()' function
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os, io
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest, uuid
+from litellm.utils import function_setup, Rules
+from datetime import datetime
+
+
+def test_empty_content():
+    """
+    Make a chat completions request with empty content -> expect this to work
+    """
+    rules_obj = Rules()
+
+    def completion():
+        pass
+
+    function_setup(
+        original_function=completion,
+        rules_obj=rules_obj,
+        start_time=datetime.now(),
+        messages=[],
+        litellm_call_id=str(uuid.uuid4()),
+    )
diff --git a/litellm/utils.py b/litellm/utils.py
index 56320abbe5..8e7c31867f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2381,210 +2381,202 @@ class Rules:
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
+def function_setup(
+    original_function, rules_obj, start_time, *args, **kwargs
+):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. 
+ try: + global callback_list, add_breadcrumb, user_logger_fn, Logging + function_id = kwargs["id"] if "id" in kwargs else None + if len(litellm.callbacks) > 0: + for callback in litellm.callbacks: + if callback not in litellm.input_callback: + litellm.input_callback.append(callback) + if callback not in litellm.success_callback: + litellm.success_callback.append(callback) + if callback not in litellm.failure_callback: + litellm.failure_callback.append(callback) + if callback not in litellm._async_success_callback: + litellm._async_success_callback.append(callback) + if callback not in litellm._async_failure_callback: + litellm._async_failure_callback.append(callback) + print_verbose( + f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" + ) + if ( + len(litellm.input_callback) > 0 + or len(litellm.success_callback) > 0 + or len(litellm.failure_callback) > 0 + ) and len(callback_list) == 0: + callback_list = list( + set( + litellm.input_callback + + litellm.success_callback + + litellm.failure_callback + ) + ) + set_callbacks(callback_list=callback_list, function_id=function_id) + ## ASYNC CALLBACKS + if len(litellm.input_callback) > 0: + removed_async_items = [] + for index, callback in enumerate(litellm.input_callback): + if inspect.iscoroutinefunction(callback): + litellm._async_input_callback.append(callback) + removed_async_items.append(index) + + # Pop the async items from input_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + litellm.input_callback.pop(index) + + if len(litellm.success_callback) > 0: + removed_async_items = [] + for index, callback in enumerate(litellm.success_callback): + if inspect.iscoroutinefunction(callback): + litellm._async_success_callback.append(callback) + removed_async_items.append(index) + elif callback == "dynamodb": + # dynamo is an async callback, it's used for the proxy and needs to be async + # we only support async dynamo db logging for 
acompletion/aembedding since that's used on proxy + litellm._async_success_callback.append(callback) + removed_async_items.append(index) + + # Pop the async items from success_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + litellm.success_callback.pop(index) + + if len(litellm.failure_callback) > 0: + removed_async_items = [] + for index, callback in enumerate(litellm.failure_callback): + if inspect.iscoroutinefunction(callback): + litellm._async_failure_callback.append(callback) + removed_async_items.append(index) + + # Pop the async items from failure_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + litellm.failure_callback.pop(index) + ### DYNAMIC CALLBACKS ### + dynamic_success_callbacks = None + dynamic_async_success_callbacks = None + if kwargs.get("success_callback", None) is not None and isinstance( + kwargs["success_callback"], list + ): + removed_async_items = [] + for index, callback in enumerate(kwargs["success_callback"]): + if ( + inspect.iscoroutinefunction(callback) + or callback == "dynamodb" + or callback == "s3" + ): + if dynamic_async_success_callbacks is not None and isinstance( + dynamic_async_success_callbacks, list + ): + dynamic_async_success_callbacks.append(callback) + else: + dynamic_async_success_callbacks = [callback] + removed_async_items.append(index) + # Pop the async items from success_callback in reverse order to avoid index issues + for index in reversed(removed_async_items): + kwargs["success_callback"].pop(index) + dynamic_success_callbacks = kwargs.pop("success_callback") + + if add_breadcrumb: + add_breadcrumb( + category="litellm.llm_call", + message=f"Positional Args: {args}, Keyword Args: {kwargs}", + level="info", + ) + if "logger_fn" in kwargs: + user_logger_fn = kwargs["logger_fn"] + # INIT LOGGER - for user-specified integrations + model = args[0] if len(args) > 0 else kwargs.get("model", None) + call_type = 
original_function.__name__ + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + ): + messages = None + if len(args) > 1: + messages = args[1] + elif kwargs.get("messages", None): + messages = kwargs["messages"] + ### PRE-CALL RULES ### + if ( + isinstance(messages, list) + and len(messages) > 0 + and isinstance(messages[0], dict) + and "content" in messages[0] + ): + rules_obj.pre_call_rules( + input="".join( + m.get("content", "") + for m in messages + if "content" in m and isinstance(m["content"], str) + ), + model=model, + ) + elif ( + call_type == CallTypes.embedding.value + or call_type == CallTypes.aembedding.value + ): + messages = args[1] if len(args) > 1 else kwargs["input"] + elif ( + call_type == CallTypes.image_generation.value + or call_type == CallTypes.aimage_generation.value + ): + messages = args[0] if len(args) > 0 else kwargs["prompt"] + elif ( + call_type == CallTypes.moderation.value + or call_type == CallTypes.amoderation.value + ): + messages = args[1] if len(args) > 1 else kwargs["input"] + elif ( + call_type == CallTypes.atext_completion.value + or call_type == CallTypes.text_completion.value + ): + messages = args[0] if len(args) > 0 else kwargs["prompt"] + elif ( + call_type == CallTypes.atranscription.value + or call_type == CallTypes.transcription.value + ): + _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"] + messages = "audio_file" + stream = True if "stream" in kwargs and kwargs["stream"] == True else False + logging_obj = Logging( + model=model, + messages=messages, + stream=stream, + litellm_call_id=kwargs["litellm_call_id"], + function_id=function_id, + call_type=call_type, + start_time=start_time, + dynamic_success_callbacks=dynamic_success_callbacks, + dynamic_async_success_callbacks=dynamic_async_success_callbacks, + langfuse_public_key=kwargs.pop("langfuse_public_key", None), + langfuse_secret=kwargs.pop("langfuse_secret", None), + ) + ## check if metadata is 
passed in + litellm_params = {"api_base": ""} + if "metadata" in kwargs: + litellm_params["metadata"] = kwargs["metadata"] + logging_obj.update_environment_variables( + model=model, + user="", + optional_params={}, + litellm_params=litellm_params, + ) + return logging_obj, kwargs + except Exception as e: + import logging + + logging.debug( + f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}" + ) + raise e + + def client(original_function): global liteDebuggerClient, get_all_keys rules_obj = Rules() - def function_setup( - start_time, *args, **kwargs - ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. - try: - global callback_list, add_breadcrumb, user_logger_fn, Logging - function_id = kwargs["id"] if "id" in kwargs else None - if litellm.use_client or ( - "use_client" in kwargs and kwargs["use_client"] == True - ): - if "lite_debugger" not in litellm.input_callback: - litellm.input_callback.append("lite_debugger") - if "lite_debugger" not in litellm.success_callback: - litellm.success_callback.append("lite_debugger") - if "lite_debugger" not in litellm.failure_callback: - litellm.failure_callback.append("lite_debugger") - if len(litellm.callbacks) > 0: - for callback in litellm.callbacks: - if callback not in litellm.input_callback: - litellm.input_callback.append(callback) - if callback not in litellm.success_callback: - litellm.success_callback.append(callback) - if callback not in litellm.failure_callback: - litellm.failure_callback.append(callback) - if callback not in litellm._async_success_callback: - litellm._async_success_callback.append(callback) - if callback not in litellm._async_failure_callback: - litellm._async_failure_callback.append(callback) - print_verbose( - f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" - ) - if ( - len(litellm.input_callback) > 0 - or len(litellm.success_callback) > 0 - or len(litellm.failure_callback) > 0 - 
) and len(callback_list) == 0: - callback_list = list( - set( - litellm.input_callback - + litellm.success_callback - + litellm.failure_callback - ) - ) - set_callbacks(callback_list=callback_list, function_id=function_id) - ## ASYNC CALLBACKS - if len(litellm.input_callback) > 0: - removed_async_items = [] - for index, callback in enumerate(litellm.input_callback): - if inspect.iscoroutinefunction(callback): - litellm._async_input_callback.append(callback) - removed_async_items.append(index) - - # Pop the async items from input_callback in reverse order to avoid index issues - for index in reversed(removed_async_items): - litellm.input_callback.pop(index) - - if len(litellm.success_callback) > 0: - removed_async_items = [] - for index, callback in enumerate(litellm.success_callback): - if inspect.iscoroutinefunction(callback): - litellm._async_success_callback.append(callback) - removed_async_items.append(index) - elif callback == "dynamodb": - # dynamo is an async callback, it's used for the proxy and needs to be async - # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy - litellm._async_success_callback.append(callback) - removed_async_items.append(index) - - # Pop the async items from success_callback in reverse order to avoid index issues - for index in reversed(removed_async_items): - litellm.success_callback.pop(index) - - if len(litellm.failure_callback) > 0: - removed_async_items = [] - for index, callback in enumerate(litellm.failure_callback): - if inspect.iscoroutinefunction(callback): - litellm._async_failure_callback.append(callback) - removed_async_items.append(index) - - # Pop the async items from failure_callback in reverse order to avoid index issues - for index in reversed(removed_async_items): - litellm.failure_callback.pop(index) - ### DYNAMIC CALLBACKS ### - dynamic_success_callbacks = None - dynamic_async_success_callbacks = None - if kwargs.get("success_callback", None) is not None and isinstance( 
- kwargs["success_callback"], list - ): - removed_async_items = [] - for index, callback in enumerate(kwargs["success_callback"]): - if ( - inspect.iscoroutinefunction(callback) - or callback == "dynamodb" - or callback == "s3" - ): - if dynamic_async_success_callbacks is not None and isinstance( - dynamic_async_success_callbacks, list - ): - dynamic_async_success_callbacks.append(callback) - else: - dynamic_async_success_callbacks = [callback] - removed_async_items.append(index) - # Pop the async items from success_callback in reverse order to avoid index issues - for index in reversed(removed_async_items): - kwargs["success_callback"].pop(index) - dynamic_success_callbacks = kwargs.pop("success_callback") - - if add_breadcrumb: - add_breadcrumb( - category="litellm.llm_call", - message=f"Positional Args: {args}, Keyword Args: {kwargs}", - level="info", - ) - if "logger_fn" in kwargs: - user_logger_fn = kwargs["logger_fn"] - # INIT LOGGER - for user-specified integrations - model = args[0] if len(args) > 0 else kwargs.get("model", None) - call_type = original_function.__name__ - if ( - call_type == CallTypes.completion.value - or call_type == CallTypes.acompletion.value - ): - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", None): - messages = kwargs["messages"] - ### PRE-CALL RULES ### - if ( - isinstance(messages, list) - and len(messages) > 0 - and isinstance(messages[0], dict) - and "content" in messages[0] - ): - rules_obj.pre_call_rules( - input="".join( - m.get("content", "") - for m in messages - if isinstance(m["content"], str) - ), - model=model, - ) - elif ( - call_type == CallTypes.embedding.value - or call_type == CallTypes.aembedding.value - ): - messages = args[1] if len(args) > 1 else kwargs["input"] - elif ( - call_type == CallTypes.image_generation.value - or call_type == CallTypes.aimage_generation.value - ): - messages = args[0] if len(args) > 0 else kwargs["prompt"] - elif ( - call_type == 
CallTypes.moderation.value - or call_type == CallTypes.amoderation.value - ): - messages = args[1] if len(args) > 1 else kwargs["input"] - elif ( - call_type == CallTypes.atext_completion.value - or call_type == CallTypes.text_completion.value - ): - messages = args[0] if len(args) > 0 else kwargs["prompt"] - elif ( - call_type == CallTypes.atranscription.value - or call_type == CallTypes.transcription.value - ): - _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"] - messages = "audio_file" - stream = True if "stream" in kwargs and kwargs["stream"] == True else False - logging_obj = Logging( - model=model, - messages=messages, - stream=stream, - litellm_call_id=kwargs["litellm_call_id"], - function_id=function_id, - call_type=call_type, - start_time=start_time, - dynamic_success_callbacks=dynamic_success_callbacks, - dynamic_async_success_callbacks=dynamic_async_success_callbacks, - langfuse_public_key=kwargs.pop("langfuse_public_key", None), - langfuse_secret=kwargs.pop("langfuse_secret", None), - ) - ## check if metadata is passed in - litellm_params = {"api_base": ""} - if "metadata" in kwargs: - litellm_params["metadata"] = kwargs["metadata"] - logging_obj.update_environment_variables( - model=model, - user="", - optional_params={}, - litellm_params=litellm_params, - ) - return logging_obj, kwargs - except Exception as e: - import logging - - logging.debug( - f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}" - ) - raise e - def check_coroutine(value) -> bool: if inspect.iscoroutine(value): return True @@ -2677,7 +2669,9 @@ def client(original_function): try: if logging_obj is None: - logging_obj, kwargs = function_setup(start_time, *args, **kwargs) + logging_obj, kwargs = function_setup( + original_function, rules_obj, start_time, *args, **kwargs + ) kwargs["litellm_logging_obj"] = logging_obj # CHECK FOR 'os.environ/' in kwargs @@ -2985,7 +2979,9 @@ def client(original_function): try: if logging_obj is None: - 
logging_obj, kwargs = function_setup(start_time, *args, **kwargs) + logging_obj, kwargs = function_setup( + original_function, rules_obj, start_time, *args, **kwargs + ) kwargs["litellm_logging_obj"] = logging_obj # [OPTIONAL] CHECK BUDGET