updates to logging

This commit is contained in:
Krrish Dholakia 2023-09-01 14:38:50 -07:00
parent 58c15d11cc
commit 34ed4cc23c
5 changed files with 56 additions and 40 deletions

View file

@ -184,10 +184,11 @@ class LiteDebugger:
data=json.dumps(litellm_data_obj), data=json.dumps(litellm_data_obj),
) )
elif call_type == "completion" and stream == True: elif call_type == "completion" and stream == True:
if len(response_obj["content"]) > 0: # don't log the empty strings
litellm_data_obj = { litellm_data_obj = {
"response_time": response_time, "response_time": response_time,
"total_cost": total_cost, "total_cost": total_cost,
"response": "streamed response", "response": response_obj["content"],
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"status": "success", "status": "success",
} }

View file

@ -36,8 +36,8 @@ litellm.set_verbose = True
score = 0 score = 0
split_per_model = { split_per_model = {
"gpt-4": 0.7, "gpt-4": 0,
"claude-instant-1.2": 0.3 "claude-instant-1.2": 1
} }
@ -81,26 +81,31 @@ try:
raise Exception("LiteLLMDebugger: post-api call not logged!") raise Exception("LiteLLMDebugger: post-api call not logged!")
if "LiteDebugger: Success/Failure Call Logging" not in output: if "LiteDebugger: Success/Failure Call Logging" not in output:
raise Exception("LiteLLMDebugger: success/failure call not logged!") raise Exception("LiteLLMDebugger: success/failure call not logged!")
except: except Exception as e:
pass pytest.fail(f"Error occurred: {e}")
# Test 3: On streaming completion call - setting client to true Test 3: On streaming completion call - setting client to true
try: try:
# Redirect stdout # Redirect stdout
old_stdout = sys.stdout old_stdout = sys.stdout
sys.stdout = new_stdout = io.StringIO() sys.stdout = new_stdout = io.StringIO()
response = completion_with_split_tests(models=split_per_model, messages=messages, stream=True, use_client=True, id="6d383c99-488d-481d-aa1b-1f94935cec44") response = completion_with_split_tests(models=split_per_model, messages=messages, stream=True, use_client=True, override_client=True, id="6d383c99-488d-481d-aa1b-1f94935cec44")
for data in response:
print(data)
# Restore stdout # Restore stdout
sys.stdout = old_stdout sys.stdout = old_stdout
output = new_stdout.getvalue().strip() output = new_stdout.getvalue().strip()
print(output)
print(f"response: {response}")
if "LiteDebugger: Pre-API Call Logging" not in output: if "LiteDebugger: Pre-API Call Logging" not in output:
raise Exception("LiteLLMDebugger: pre-api call not logged!") raise Exception("LiteLLMDebugger: pre-api call not logged!")
if "LiteDebugger: Post-API Call Logging" not in output: if "LiteDebugger: Post-API Call Logging" not in output:
raise Exception("LiteLLMDebugger: post-api call not logged!") raise Exception("LiteLLMDebugger: post-api call not logged!")
if "LiteDebugger: Success/Failure Call Logging" not in output: if "LiteDebugger: Success/Failure Call Logging" not in output:
raise Exception("LiteLLMDebugger: success/failure call not logged!") raise Exception("LiteLLMDebugger: success/failure call not logged!")
except: except Exception as e:
pass pytest.fail(f"Error occurred: {e}")

View file

@ -157,13 +157,14 @@ class CallTypes(Enum):
class Logging: class Logging:
global supabaseClient, liteDebuggerClient global supabaseClient, liteDebuggerClient
def __init__(self, model, messages, stream, call_type, litellm_call_id, function_id): def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id):
if call_type not in [item.value for item in CallTypes]: if call_type not in [item.value for item in CallTypes]:
allowed_values = ", ".join([item.value for item in CallTypes]) allowed_values = ", ".join([item.value for item in CallTypes])
raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}") raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}")
self.model = model self.model = model
self.messages = messages self.messages = messages
self.stream = stream self.stream = stream
self.start_time = start_time # log the call start time
self.call_type = call_type self.call_type = call_type
self.litellm_call_id = litellm_call_id self.litellm_call_id = litellm_call_id
self.function_id = function_id self.function_id = function_id
@ -330,11 +331,15 @@ class Logging:
pass pass
def success_handler(self, result, start_time, end_time): def success_handler(self, result, start_time=None, end_time=None):
print_verbose( print_verbose(
f"Logging Details LiteLLM-Success Call" f"Logging Details LiteLLM-Success Call"
) )
try: try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
for callback in litellm.success_callback: for callback in litellm.success_callback:
try: try:
if callback == "lite_debugger": if callback == "lite_debugger":
@ -366,11 +371,16 @@ class Logging:
) )
pass pass
def failure_handler(self, exception, traceback_exception, start_time, end_time): def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose( print_verbose(
f"Logging Details LiteLLM-Failure Call" f"Logging Details LiteLLM-Failure Call"
) )
try: try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
for callback in litellm.failure_callback: for callback in litellm.failure_callback:
try: try:
if callback == "lite_debugger": if callback == "lite_debugger":
@ -451,7 +461,7 @@ def client(original_function):
global liteDebuggerClient, get_all_keys global liteDebuggerClient, get_all_keys
def function_setup( def function_setup(
*args, **kwargs start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try: try:
global callback_list, add_breadcrumb, user_logger_fn, Logging global callback_list, add_breadcrumb, user_logger_fn, Logging
@ -495,7 +505,7 @@ def client(original_function):
elif call_type == CallTypes.embedding.value: elif call_type == CallTypes.embedding.value:
messages = args[1] if len(args) > 1 else kwargs["input"] messages = args[1] if len(args) > 1 else kwargs["input"]
stream = True if "stream" in kwargs and kwargs["stream"] == True else False stream = True if "stream" in kwargs and kwargs["stream"] == True else False
logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type) logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
return logging_obj return logging_obj
except: # DO NOT BLOCK running the function because of this except: # DO NOT BLOCK running the function because of this
print_verbose(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}") print_verbose(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
@ -521,14 +531,13 @@ def client(original_function):
pass pass
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start_time = None start_time = datetime.datetime.now()
result = None result = None
litellm_call_id = str(uuid.uuid4()) litellm_call_id = str(uuid.uuid4())
kwargs["litellm_call_id"] = litellm_call_id kwargs["litellm_call_id"] = litellm_call_id
logging_obj = function_setup(*args, **kwargs) logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj kwargs["litellm_logging_obj"] = logging_obj
try: try:
start_time = datetime.datetime.now()
# [OPTIONAL] CHECK CACHE # [OPTIONAL] CHECK CACHE
# remove this after deprecating litellm.caching # remove this after deprecating litellm.caching
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None: if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
@ -543,12 +552,11 @@ def client(original_function):
# MODEL CALL # MODEL CALL
result = original_function(*args, **kwargs) result = original_function(*args, **kwargs)
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
# LOG SUCCESS
logging_obj.success_handler(result, start_time, end_time)
if "stream" in kwargs and kwargs["stream"] == True: if "stream" in kwargs and kwargs["stream"] == True:
# TODO: Add to cache for streaming # TODO: Add to cache for streaming
return result return result
# [OPTIONAL] ADD TO CACHE # [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs) litellm.cache.add_cache(result, *args, **kwargs)
@ -557,7 +565,8 @@ def client(original_function):
if litellm.use_client == True: if litellm.use_client == True:
result['litellm_call_id'] = litellm_call_id result['litellm_call_id'] = litellm_call_id
# LOG SUCCESS # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
my_thread = threading.Thread( my_thread = threading.Thread(
target=handle_success, args=(args, kwargs, result, start_time, end_time) target=handle_success, args=(args, kwargs, result, start_time, end_time)
) # don't interrupt execution of main thread ) # don't interrupt execution of main thread
@ -568,7 +577,8 @@ def client(original_function):
traceback_exception = traceback.format_exc() traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception) crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
my_thread = threading.Thread( my_thread = threading.Thread(
target=handle_failure, target=handle_failure,
args=(e, traceback_exception, start_time, end_time, args, kwargs), args=(e, traceback_exception, start_time, end_time, args, kwargs),
@ -1833,7 +1843,7 @@ class CustomStreamWrapper:
completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk) completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
# LOGGING # LOGGING
# self.logging_obj.post_call(completion_obj["content"]) threading.Thread(target=self.logging_obj.success_handler, args=(completion_obj,)).start()
# return this for all models # return this for all models
return {"choices": [{"delta": completion_obj}]} return {"choices": [{"delta": completion_obj}]}
except: except:
@ -1933,7 +1943,7 @@ def get_model_split_test(models, completion_call_id):
) )
def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs): def completion_with_split_tests(models={}, messages=[], use_client=False, override_client=False, **kwargs):
""" """
Example Usage: Example Usage:
@ -1945,7 +1955,7 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, **kwar
completion_with_split_tests(models=models, messages=messages) completion_with_split_tests(models=models, messages=messages)
""" """
import random import random
if use_client: if use_client and not override_client:
if "id" not in kwargs or kwargs["id"] is None: if "id" not in kwargs or kwargs["id"] is None:
raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.") raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
# get the most recent model split list from server # get the most recent model split list from server