forked from phoenix/litellm-mirror
updates to logging
This commit is contained in:
parent
58c15d11cc
commit
34ed4cc23c
5 changed files with 56 additions and 40 deletions
|
@ -157,13 +157,14 @@ class CallTypes(Enum):
|
|||
class Logging:
|
||||
global supabaseClient, liteDebuggerClient
|
||||
|
||||
def __init__(self, model, messages, stream, call_type, litellm_call_id, function_id):
|
||||
def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id):
|
||||
if call_type not in [item.value for item in CallTypes]:
|
||||
allowed_values = ", ".join([item.value for item in CallTypes])
|
||||
raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}")
|
||||
self.model = model
|
||||
self.messages = messages
|
||||
self.stream = stream
|
||||
self.start_time = start_time # log the call start time
|
||||
self.call_type = call_type
|
||||
self.litellm_call_id = litellm_call_id
|
||||
self.function_id = function_id
|
||||
|
@ -330,11 +331,15 @@ class Logging:
|
|||
pass
|
||||
|
||||
|
||||
def success_handler(self, result, start_time, end_time):
|
||||
def success_handler(self, result, start_time=None, end_time=None):
|
||||
print_verbose(
|
||||
f"Logging Details LiteLLM-Success Call"
|
||||
)
|
||||
try:
|
||||
if start_time is None:
|
||||
start_time = self.start_time
|
||||
if end_time is None:
|
||||
end_time = datetime.datetime.now()
|
||||
for callback in litellm.success_callback:
|
||||
try:
|
||||
if callback == "lite_debugger":
|
||||
|
@ -366,11 +371,16 @@ class Logging:
|
|||
)
|
||||
pass
|
||||
|
||||
def failure_handler(self, exception, traceback_exception, start_time, end_time):
|
||||
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
|
||||
print_verbose(
|
||||
f"Logging Details LiteLLM-Failure Call"
|
||||
)
|
||||
try:
|
||||
if start_time is None:
|
||||
start_time = self.start_time
|
||||
if end_time is None:
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
for callback in litellm.failure_callback:
|
||||
try:
|
||||
if callback == "lite_debugger":
|
||||
|
@ -451,7 +461,7 @@ def client(original_function):
|
|||
global liteDebuggerClient, get_all_keys
|
||||
|
||||
def function_setup(
|
||||
*args, **kwargs
|
||||
start_time, *args, **kwargs
|
||||
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
|
||||
try:
|
||||
global callback_list, add_breadcrumb, user_logger_fn, Logging
|
||||
|
@ -495,7 +505,7 @@ def client(original_function):
|
|||
elif call_type == CallTypes.embedding.value:
|
||||
messages = args[1] if len(args) > 1 else kwargs["input"]
|
||||
stream = True if "stream" in kwargs and kwargs["stream"] == True else False
|
||||
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type)
|
||||
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
|
||||
return logging_obj
|
||||
except: # DO NOT BLOCK running the function because of this
|
||||
print_verbose(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
|
||||
|
@ -521,14 +531,13 @@ def client(original_function):
|
|||
pass
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = None
|
||||
start_time = datetime.datetime.now()
|
||||
result = None
|
||||
litellm_call_id = str(uuid.uuid4())
|
||||
kwargs["litellm_call_id"] = litellm_call_id
|
||||
logging_obj = function_setup(*args, **kwargs)
|
||||
logging_obj = function_setup(start_time, *args, **kwargs)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
try:
|
||||
start_time = datetime.datetime.now()
|
||||
# [OPTIONAL] CHECK CACHE
|
||||
# remove this after deprecating litellm.caching
|
||||
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
|
||||
|
@ -543,12 +552,11 @@ def client(original_function):
|
|||
# MODEL CALL
|
||||
result = original_function(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
# LOG SUCCESS
|
||||
logging_obj.success_handler(result, start_time, end_time)
|
||||
|
||||
if "stream" in kwargs and kwargs["stream"] == True:
|
||||
# TODO: Add to cache for streaming
|
||||
return result
|
||||
|
||||
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
||||
litellm.cache.add_cache(result, *args, **kwargs)
|
||||
|
@ -557,7 +565,8 @@ def client(original_function):
|
|||
if litellm.use_client == True:
|
||||
result['litellm_call_id'] = litellm_call_id
|
||||
|
||||
# LOG SUCCESS
|
||||
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
|
||||
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
|
||||
my_thread = threading.Thread(
|
||||
target=handle_success, args=(args, kwargs, result, start_time, end_time)
|
||||
) # don't interrupt execution of main thread
|
||||
|
@ -568,7 +577,8 @@ def client(original_function):
|
|||
traceback_exception = traceback.format_exc()
|
||||
crash_reporting(*args, **kwargs, exception=traceback_exception)
|
||||
end_time = datetime.datetime.now()
|
||||
logging_obj.failure_handler(e, traceback_exception, start_time, end_time)
|
||||
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
||||
threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
|
||||
my_thread = threading.Thread(
|
||||
target=handle_failure,
|
||||
args=(e, traceback_exception, start_time, end_time, args, kwargs),
|
||||
|
@ -1833,7 +1843,7 @@ class CustomStreamWrapper:
|
|||
completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
|
||||
|
||||
# LOGGING
|
||||
# self.logging_obj.post_call(completion_obj["content"])
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(completion_obj,)).start()
|
||||
# return this for all models
|
||||
return {"choices": [{"delta": completion_obj}]}
|
||||
except:
|
||||
|
@ -1933,7 +1943,7 @@ def get_model_split_test(models, completion_call_id):
|
|||
)
|
||||
|
||||
|
||||
def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs):
|
||||
def completion_with_split_tests(models={}, messages=[], use_client=False, override_client=False, **kwargs):
|
||||
"""
|
||||
Example Usage:
|
||||
|
||||
|
@ -1945,7 +1955,7 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, **kwar
|
|||
completion_with_split_tests(models=models, messages=messages)
|
||||
"""
|
||||
import random
|
||||
if use_client:
|
||||
if use_client and not override_client:
|
||||
if "id" not in kwargs or kwargs["id"] is None:
|
||||
raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
|
||||
# get the most recent model split list from server
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue