fix(utils.py): function_setup empty message fix

fixes https://github.com/BerriAI/litellm/issues/2858
This commit is contained in:
Krrish Dholakia 2024-04-18 07:32:29 -07:00
parent b38c09c87f
commit 6eb8fe35c8
2 changed files with 231 additions and 202 deletions

View file

@ -2381,210 +2381,202 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def function_setup(
original_function, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:
litellm.success_callback.append(callback)
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback)
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
print_verbose(
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
)
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback) > 0
) and len(callback_list) == 0:
callback_list = list(
set(
litellm.input_callback
+ litellm.success_callback
+ litellm.failure_callback
)
)
set_callbacks(callback_list=callback_list, function_id=function_id)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.input_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_input_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from input_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.input_callback.pop(index)
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
elif callback == "dynamodb":
# dynamo is an async callback, it's used for the proxy and needs to be async
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
if len(litellm.failure_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.failure_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = None
dynamic_async_success_callbacks = None
if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list
):
removed_async_items = []
for index, callback in enumerate(kwargs["success_callback"]):
if (
inspect.iscoroutinefunction(callback)
or callback == "dynamodb"
or callback == "s3"
):
if dynamic_async_success_callbacks is not None and isinstance(
dynamic_async_success_callbacks, list
):
dynamic_async_success_callbacks.append(callback)
else:
dynamic_async_success_callbacks = [callback]
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
kwargs["success_callback"].pop(index)
dynamic_success_callbacks = kwargs.pop("success_callback")
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
message=f"Positional Args: {args}, Keyword Args: {kwargs}",
level="info",
)
if "logger_fn" in kwargs:
user_logger_fn = kwargs["logger_fn"]
# INIT LOGGER - for user-specified integrations
model = args[0] if len(args) > 0 else kwargs.get("model", None)
call_type = original_function.__name__
if (
call_type == CallTypes.completion.value
or call_type == CallTypes.acompletion.value
):
messages = None
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
messages = kwargs["messages"]
### PRE-CALL RULES ###
if (
isinstance(messages, list)
and len(messages) > 0
and isinstance(messages[0], dict)
and "content" in messages[0]
):
rules_obj.pre_call_rules(
input="".join(
m.get("content", "")
for m in messages
if "content" in m and isinstance(m["content"], str)
),
model=model,
)
elif (
call_type == CallTypes.embedding.value
or call_type == CallTypes.aembedding.value
):
messages = args[1] if len(args) > 1 else kwargs["input"]
elif (
call_type == CallTypes.image_generation.value
or call_type == CallTypes.aimage_generation.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.moderation.value
or call_type == CallTypes.amoderation.value
):
messages = args[1] if len(args) > 1 else kwargs["input"]
elif (
call_type == CallTypes.atext_completion.value
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
):
_file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
messages = "audio_file"
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(
model=model,
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
function_id=function_id,
call_type=call_type,
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
langfuse_public_key=kwargs.pop("langfuse_public_key", None),
langfuse_secret=kwargs.pop("langfuse_secret", None),
)
## check if metadata is passed in
litellm_params = {"api_base": ""}
if "metadata" in kwargs:
litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables(
model=model,
user="",
optional_params={},
litellm_params=litellm_params,
)
return logging_obj, kwargs
except Exception as e:
import logging
logging.debug(
f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
)
raise e
def client(original_function):
global liteDebuggerClient, get_all_keys
rules_obj = Rules()
def function_setup(
start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
function_id = kwargs["id"] if "id" in kwargs else None
if litellm.use_client or (
"use_client" in kwargs and kwargs["use_client"] == True
):
if "lite_debugger" not in litellm.input_callback:
litellm.input_callback.append("lite_debugger")
if "lite_debugger" not in litellm.success_callback:
litellm.success_callback.append("lite_debugger")
if "lite_debugger" not in litellm.failure_callback:
litellm.failure_callback.append("lite_debugger")
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:
litellm.success_callback.append(callback)
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback)
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
print_verbose(
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
)
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback) > 0
) and len(callback_list) == 0:
callback_list = list(
set(
litellm.input_callback
+ litellm.success_callback
+ litellm.failure_callback
)
)
set_callbacks(callback_list=callback_list, function_id=function_id)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.input_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_input_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from input_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.input_callback.pop(index)
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
elif callback == "dynamodb":
# dynamo is an async callback, it's used for the proxy and needs to be async
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
if len(litellm.failure_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.failure_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = None
dynamic_async_success_callbacks = None
if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list
):
removed_async_items = []
for index, callback in enumerate(kwargs["success_callback"]):
if (
inspect.iscoroutinefunction(callback)
or callback == "dynamodb"
or callback == "s3"
):
if dynamic_async_success_callbacks is not None and isinstance(
dynamic_async_success_callbacks, list
):
dynamic_async_success_callbacks.append(callback)
else:
dynamic_async_success_callbacks = [callback]
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
kwargs["success_callback"].pop(index)
dynamic_success_callbacks = kwargs.pop("success_callback")
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
message=f"Positional Args: {args}, Keyword Args: {kwargs}",
level="info",
)
if "logger_fn" in kwargs:
user_logger_fn = kwargs["logger_fn"]
# INIT LOGGER - for user-specified integrations
model = args[0] if len(args) > 0 else kwargs.get("model", None)
call_type = original_function.__name__
if (
call_type == CallTypes.completion.value
or call_type == CallTypes.acompletion.value
):
messages = None
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
messages = kwargs["messages"]
### PRE-CALL RULES ###
if (
isinstance(messages, list)
and len(messages) > 0
and isinstance(messages[0], dict)
and "content" in messages[0]
):
rules_obj.pre_call_rules(
input="".join(
m.get("content", "")
for m in messages
if isinstance(m["content"], str)
),
model=model,
)
elif (
call_type == CallTypes.embedding.value
or call_type == CallTypes.aembedding.value
):
messages = args[1] if len(args) > 1 else kwargs["input"]
elif (
call_type == CallTypes.image_generation.value
or call_type == CallTypes.aimage_generation.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.moderation.value
or call_type == CallTypes.amoderation.value
):
messages = args[1] if len(args) > 1 else kwargs["input"]
elif (
call_type == CallTypes.atext_completion.value
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
):
_file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
messages = "audio_file"
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(
model=model,
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
function_id=function_id,
call_type=call_type,
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
langfuse_public_key=kwargs.pop("langfuse_public_key", None),
langfuse_secret=kwargs.pop("langfuse_secret", None),
)
## check if metadata is passed in
litellm_params = {"api_base": ""}
if "metadata" in kwargs:
litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables(
model=model,
user="",
optional_params={},
litellm_params=litellm_params,
)
return logging_obj, kwargs
except Exception as e:
import logging
logging.debug(
f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
)
raise e
def check_coroutine(value) -> bool:
if inspect.iscoroutine(value):
return True
@ -2677,7 +2669,9 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(
original_function, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
# CHECK FOR 'os.environ/' in kwargs
@ -2985,7 +2979,9 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(
original_function, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET