updates to logging

Krrish Dholakia 2023-09-01 14:38:50 -07:00
parent 58c15d11cc
commit 34ed4cc23c
5 changed files with 56 additions and 40 deletions


@@ -157,13 +157,14 @@ class CallTypes(Enum):
class Logging:
global supabaseClient, liteDebuggerClient
def __init__(self, model, messages, stream, call_type, litellm_call_id, function_id):
def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id):
if call_type not in [item.value for item in CallTypes]:
allowed_values = ", ".join([item.value for item in CallTypes])
raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}")
self.model = model
self.messages = messages
self.stream = stream
self.start_time = start_time # log the call start time
self.call_type = call_type
self.litellm_call_id = litellm_call_id
self.function_id = function_id
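The constructor change above threads the call's wall-clock start time through to the Logging object. A minimal sketch of how the object is constructed with the new argument, mirroring the function_setup call later in this diff (the model, messages, and ids below are placeholder values, not part of the commit):

import datetime

start_time = datetime.datetime.now()  # captured before any provider call is made
logging_obj = Logging(
    model="gpt-3.5-turbo",                          # placeholder
    messages=[{"role": "user", "content": "hi"}],   # placeholder
    stream=False,
    call_type="completion",
    start_time=start_time,      # new argument: stored as self.start_time
    litellm_call_id="placeholder-uuid",
    function_id=None,
)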
@@ -330,11 +331,15 @@ class Logging:
pass
def success_handler(self, result, start_time, end_time):
def success_handler(self, result, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@@ -366,11 +371,16 @@ class Logging:
)
pass
def failure_handler(self, exception, traceback_exception, start_time, end_time):
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Failure Call"
)
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
for callback in litellm.failure_callback:
try:
if callback == "lite_debugger":
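Both handlers now accept optional timestamps and fall back to the value recorded at construction (or to the current time for end_time), so callers that fire them from other places, e.g. the streaming wrapper further down, do not have to track the times themselves. A small self-contained sketch of that fallback pattern (illustrative only, not the litellm code):

import datetime

class TimedLogger:
    def __init__(self, start_time):
        self.start_time = start_time  # recorded when the call began

    def success_handler(self, result, start_time=None, end_time=None):
        if start_time is None:
            start_time = self.start_time
        if end_time is None:
            end_time = datetime.datetime.now()
        print(f"success after {(end_time - start_time).total_seconds():.3f}s: {result}")

logger = TimedLogger(datetime.datetime.now())
logger.success_handler("ok")   # no timestamps passed; the defaults fill them in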
@@ -451,7 +461,7 @@ def client(original_function):
global liteDebuggerClient, get_all_keys
def function_setup(
*args, **kwargs
start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
@@ -495,7 +505,7 @@ def client(original_function):
elif call_type == CallTypes.embedding.value:
messages = args[1] if len(args) > 1 else kwargs["input"]
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type)
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
return logging_obj
except: # DO NOT BLOCK running the function because of this
print_verbose(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
@@ -521,14 +531,13 @@ def client(original_function):
pass
def wrapper(*args, **kwargs):
start_time = None
start_time = datetime.datetime.now()
result = None
litellm_call_id = str(uuid.uuid4())
kwargs["litellm_call_id"] = litellm_call_id
logging_obj = function_setup(*args, **kwargs)
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
try:
start_time = datetime.datetime.now()
# [OPTIONAL] CHECK CACHE
# remove this after deprecating litellm.caching
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
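With this hunk, start_time is taken at the very top of wrapper and handed to function_setup instead of being set inside the try block, so even an exception raised during setup or before the model call has a real start timestamp to log. A sketch of that ordering with hypothetical names (not the actual litellm decorator):

import datetime
import functools

def timed_client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        start_time = datetime.datetime.now()   # captured before any setup work
        # setup that may itself fail would run here, e.g. function_setup(start_time, ...)
        try:
            result = original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
            # success logging gets (result, start_time, end_time)
            return result
        except Exception:
            end_time = datetime.datetime.now()
            # failure logging still has a valid start_time, even if the
            # exception happened before the provider was ever called
            raise
    return wrapper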
@@ -543,12 +552,11 @@ def client(original_function):
# MODEL CALL
result = original_function(*args, **kwargs)
end_time = datetime.datetime.now()
# LOG SUCCESS
logging_obj.success_handler(result, start_time, end_time)
if "stream" in kwargs and kwargs["stream"] == True:
# TODO: Add to cache for streaming
return result
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
@@ -557,7 +565,8 @@ def client(original_function):
if litellm.use_client == True:
result['litellm_call_id'] = litellm_call_id
# LOG SUCCESS
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
my_thread = threading.Thread(
target=handle_success, args=(args, kwargs, result, start_time, end_time)
) # don't interrupt execution of main thread
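Success logging now runs on a background thread (alongside the existing handle_success thread), so serializing and shipping logs cannot delay returning the result to the caller. A minimal example of the same fire-and-forget pattern (the handler and values here are placeholders, not the litellm implementation):

import datetime
import threading

def success_handler(result, start_time, end_time):
    # stand-in for Logging.success_handler
    print("logged:", result, "took", end_time - start_time)

start_time = datetime.datetime.now()
result = {"choices": [{"message": {"content": "hi"}}]}   # placeholder response
end_time = datetime.datetime.now()

threading.Thread(target=success_handler, args=(result, start_time, end_time)).start()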
@@ -568,7 +577,8 @@ def client(original_function):
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
logging_obj.failure_handler(e, traceback_exception, start_time, end_time)
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
my_thread = threading.Thread(
target=handle_failure,
args=(e, traceback_exception, start_time, end_time, args, kwargs),
@@ -1833,7 +1843,7 @@ class CustomStreamWrapper:
completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
# LOGGING
# self.logging_obj.post_call(completion_obj["content"])
threading.Thread(target=self.logging_obj.success_handler, args=(completion_obj,)).start()
# return this for all models
return {"choices": [{"delta": completion_obj}]}
except:
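In the streaming wrapper, each decoded chunk now kicks off success_handler on its own thread in place of the commented-out post_call hook, so per-chunk logging never blocks the consumer iterating over the stream. A standalone sketch of that per-chunk pattern (the wrapper class and chunk format here are simplified stand-ins):

import threading

class StreamingSketch:
    def __init__(self, chunks, on_chunk):
        self._chunks = iter(chunks)
        self._on_chunk = on_chunk   # stand-in for logging_obj.success_handler

    def __iter__(self):
        return self

    def __next__(self):
        content = next(self._chunks)  # raises StopIteration when exhausted
        # log in the background so the caller receives the chunk immediately
        threading.Thread(target=self._on_chunk, args=({"content": content},)).start()
        return {"choices": [{"delta": {"content": content}}]}

for chunk in StreamingSketch(["Hel", "lo"], on_chunk=print):
    pass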
@@ -1933,7 +1943,7 @@ def get_model_split_test(models, completion_call_id):
)
def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs):
def completion_with_split_tests(models={}, messages=[], use_client=False, override_client=False, **kwargs):
"""
Example Usage:
@@ -1945,7 +1955,7 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, **kwar
completion_with_split_tests(models=models, messages=messages)
"""
import random
if use_client:
if use_client and not override_client:
if "id" not in kwargs or kwargs["id"] is None:
raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
# get the most recent model split list from server
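The new override_client flag lets a caller that has use_client enabled skip the server round-trip that fetches updated split-test weights, along with the id requirement that goes with it, for a single call. Hypothetical usage, assuming completion_with_split_tests is importable from litellm and that the weights format matches the function's docstring (values below are placeholders):

response = completion_with_split_tests(
    models={"gpt-3.5-turbo": 70, "claude-instant-1": 30},   # placeholder weights
    messages=[{"role": "user", "content": "hi"}],
    use_client=True,
    override_client=True,   # skip the server-side split-test sync for this call
)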