fixing exception mapping

Krrish Dholakia 2023-08-05 09:52:01 -07:00
parent 9b0e9bf57c
commit 92a13958ce
8 changed files with 188 additions and 115 deletions


@@ -25,41 +25,44 @@ def print_verbose(print_statement):
####### LOGGING ###################
#Logging function -> log the exact model details + what's being sent | Non-Blocking
def logging(model, input, azure=False, additional_args={}, logger_fn=None, exception=None):
def logging(model=None, input=None, azure=False, additional_args={}, logger_fn=None, exception=None):
try:
model_call_details = {}
model_call_details["model"] = model
model_call_details["azure"] = azure
# log exception details
if model:
model_call_details["model"] = model
if azure:
model_call_details["azure"] = azure
if exception:
model_call_details["original_exception"] = exception
if litellm.telemetry:
safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
model_call_details["input"] = input
if input:
model_call_details["input"] = input
# log additional call details -> api key, etc.
if azure == True or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_embedding_models:
model_call_details["api_type"] = openai.api_type
model_call_details["api_base"] = openai.api_base
model_call_details["api_version"] = openai.api_version
model_call_details["api_key"] = openai.api_key
elif "replicate" in model:
model_call_details["api_key"] = os.environ.get("REPLICATE_API_TOKEN")
elif model in litellm.anthropic_models:
model_call_details["api_key"] = os.environ.get("ANTHROPIC_API_KEY")
elif model in litellm.cohere_models:
model_call_details["api_key"] = os.environ.get("COHERE_API_KEY")
model_call_details["additional_args"] = additional_args
if model:
if azure == True or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_embedding_models:
model_call_details["api_type"] = openai.api_type
model_call_details["api_base"] = openai.api_base
model_call_details["api_version"] = openai.api_version
model_call_details["api_key"] = openai.api_key
elif "replicate" in model:
model_call_details["api_key"] = os.environ.get("REPLICATE_API_TOKEN")
elif model in litellm.anthropic_models:
model_call_details["api_key"] = os.environ.get("ANTHROPIC_API_KEY")
elif model in litellm.cohere_models:
model_call_details["api_key"] = os.environ.get("COHERE_API_KEY")
model_call_details["additional_args"] = additional_args
## User Logging -> if you pass in a custom logging function or want to use sentry breadcrumbs
print_verbose(f"Basic model call details: {model_call_details}")
print_verbose(f"Logging Details: logger_fn - {logger_fn} | callable(logger_fn) - {callable(logger_fn)}")
if logger_fn and callable(logger_fn):
try:
logger_fn(model_call_details) # Expectation: any logger function passed in by the user should accept a dict object
except:
print_verbose(f"[Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
except:
traceback.print_exc()
except Exception as e:
print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
except Exception as e:
print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
pass
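The hunk above makes every parameter of logging() optional, so the same helper can be reused from failure paths where only the exception is known. A minimal sketch of both call sites, assuming the logging() helper from this file is in scope; call_provider is a hypothetical function used purely for illustration:

# Sketch only: the same helper serves the pre-call path (model and input known)
# and the failure path (only the exception known). `call_provider` is hypothetical.
def call_provider(model, prompt, logger_fn=None):
    logging(model=model, input=prompt, logger_fn=logger_fn)   # pre-call log
    try:
        raise RuntimeError("simulated provider failure")
    except Exception as e:
        logging(logger_fn=logger_fn, exception=e)             # failure log: model/input omitted
        raise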
####### CLIENT ###################
@@ -67,7 +70,7 @@ def logging(model, input, azure=False, additional_args={}, logger_fn=None, exception=None):
def client(original_function):
def function_setup(*args, **kwargs): #just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb
global callback_list, add_breadcrumb, user_logger_fn
if (len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0) and len(callback_list) == 0:
callback_list = list(set(litellm.success_callback + litellm.failure_callback))
set_callbacks(callback_list=callback_list,)
@@ -77,13 +80,15 @@ def client(original_function):
message=f"Positional Args: {args}, Keyword Args: {kwargs}",
level="info",
)
if "logger_fn" in kwargs:
user_logger_fn = kwargs["logger_fn"]
except: # DO NOT BLOCK running the function because of this
print_verbose(f"[Non-Blocking] {traceback.format_exc()}")
pass
def wrapper(*args, **kwargs):
try:
function_setup(args, kwargs)
function_setup(*args, **kwargs)
## MODEL CALL
start_time = datetime.datetime.now()
result = original_function(*args, **kwargs)
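function_setup() is now called with the arguments unpacked, and the new user_logger_fn global remembers the caller's logger_fn so the failure handlers and exception_type() below can reach it. A compressed sketch of that decorator shape, with the bodies trimmed to the parts this commit touches:

# Compressed sketch of the client() decorator after this change; bodies trimmed.
user_logger_fn = None

def client(original_function):
    def function_setup(*args, **kwargs):
        global user_logger_fn
        if "logger_fn" in kwargs:
            user_logger_fn = kwargs["logger_fn"]   # remembered for handle_failure/handle_success

    def wrapper(*args, **kwargs):
        function_setup(*args, **kwargs)            # unpacked, matching the fixed call above
        return original_function(*args, **kwargs)

    return wrapper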
@@ -100,6 +105,51 @@ def client(original_function):
return wrapper
####### HELPER FUNCTIONS ################
def get_optional_params(
# 12 optional params
functions = [],
function_call = "",
temperature = 1,
top_p = 1,
n = 1,
stream = False,
stop = None,
max_tokens = float('inf'),
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = {},
user = "",
deployment_id = None
):
optional_params = {}
if functions != []:
optional_params["functions"] = functions
if function_call != "":
optional_params["function_call"] = function_call
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
optional_params["top_p"] = top_p
if n != 1:
optional_params["n"] = n
if stream:
optional_params["stream"] = stream
if stop != None:
optional_params["stop"] = stop
if max_tokens != float('inf'):
optional_params["max_tokens"] = max_tokens
if presence_penalty != 0:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
if logit_bias != {}:
optional_params["logit_bias"] = logit_bias
if user != "":
optional_params["user"] = user
if deployment_id != None:
optional_params["deployment_id"] = user
return optional_params
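get_optional_params() forwards only the arguments that differ from their defaults, so provider calls are not padded with no-op keys. A quick usage sketch (the values are arbitrary):

# Only non-default values survive the filtering.
params = get_optional_params(temperature=0.2, stream=True, max_tokens=256)
print(params)   # {'temperature': 0.2, 'stream': True, 'max_tokens': 256}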
def set_callbacks(callback_list):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger
try:
@@ -150,8 +200,8 @@ def set_callbacks(callback_list):
def handle_failure(exception, traceback_exception, args, kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel
try:
print_verbose(f"handle_failure args: {args}")
print_verbose(f"handle_failure kwargs: {kwargs}")
# print_verbose(f"handle_failure args: {args}")
# print_verbose(f"handle_failure kwargs: {kwargs}")
success_handler = additional_details.pop("success_handler", None)
failure_handler = additional_details.pop("failure_handler", None)
@@ -159,7 +209,8 @@ def handle_failure(exception, traceback_exception, args, kwargs):
additional_details["Event_Name"] = additional_details.pop("failed_event_name", "litellm.failed_query")
print_verbose(f"self.failure_callback: {litellm.failure_callback}")
print_verbose(f"additional_details: {additional_details}")
# print_verbose(f"additional_details: {additional_details}")
for callback in litellm.failure_callback:
try:
if callback == "slack":
@@ -206,7 +257,9 @@ def handle_failure(exception, traceback_exception, args, kwargs):
}
failure_handler(call_details)
pass
except:
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, exception=e)
pass
def handle_success(args, kwargs, result, start_time, end_time):
@@ -245,12 +298,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
if success_handler and callable(success_handler):
success_handler(args, kwargs)
pass
except:
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, exception=e)
print_verbose(f"Success Callback Error - {traceback.format_exc()}")
pass
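Both handle_failure() and handle_success() now report their own internal errors through logging() with the user_logger_fn captured in function_setup(), instead of failing silently, while still never letting a broken callback interrupt the caller. The pattern, reduced to a sketch (run_user_callback and notify_success are hypothetical names):

# Non-blocking callback pattern: a broken user callback is logged, then ignored.
def run_user_callback(result):
    raise ValueError("user callback blew up")      # hypothetical failing callback

def notify_success(result):
    try:
        run_user_callback(result)
    except Exception as e:
        logging(logger_fn=user_logger_fn, exception=e)
        pass                                       # never propagate callback errors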
def exception_type(model, original_exception):
global user_logger_fn
exception_mapping_worked = False
try:
if isinstance(original_exception, OpenAIError):
# Handle the OpenAIError
@@ -265,32 +322,46 @@ def exception_type(model, original_exception):
if "status_code" in original_exception:
print_verbose(f"status_code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(f"AnthropicException - {original_exception.message}")
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(f"AnthropicException - {original_exception.message}", f"{model}")
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(f"AnthropicException - {original_exception.message}")
elif "replicate" in model:
if "Incorrect authentication token" in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"ReplicateException - {error_str}")
elif exception_type == "ModelError":
exception_mapping_worked = True
raise InvalidRequestError(f"ReplicateException - {error_str}", f"{model}")
elif "Request was throttled" in error_str:
exception_mapping_worked = True
raise RateLimitError(f"ReplicateException - {error_str}")
elif exception_type == "ReplicateError": ## ReplicateError implies an error on Replicate server side, not user side
raise ServiceUnavailableError(f"ReplicateException - {error_str}")
elif model == "command-nightly": #Cohere
if "invalid api token" in error_str or "No API key provided." in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"CohereException - {error_str}")
elif "too many tokens" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(f"CohereException - {error_str}", f"{model}")
elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(f"CohereException - {original_exception.message}")
raise original_exception # base case - raise the original exception
else:
raise original_exception
except:
raise original_exception
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, additional_args={"original_exception": original_exception}, exception=e)
if exception_mapping_worked:
raise e
else: # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
raise original_exception
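The exception_mapping_worked flag is the heart of this commit: once a provider error has been translated into an OpenAI-style exception, the mapped error is re-raised, but if the mapping itself fails, the caller still receives the untouched provider exception. A condensed sketch of that control flow, reusing the names from the function above with the provider-specific checks trimmed:

# Condensed control flow of exception_type(); provider-specific checks trimmed.
def map_provider_error(model, original_exception):
    exception_mapping_worked = False
    try:
        if model in litellm.anthropic_models and getattr(original_exception, "status_code", None) == 401:
            exception_mapping_worked = True
            raise AuthenticationError(f"AnthropicException - {original_exception.message}")
        raise original_exception                   # base case: nothing matched
    except Exception as e:
        logging(logger_fn=user_logger_fn, exception=e)
        if exception_mapping_worked:
            raise e                                # mapping succeeded: surface the mapped error
        else:
            raise original_exception               # mapping failed: never hide the real error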
def safe_crash_reporting(model=None, exception=None, azure=None):
data = {
@@ -323,7 +394,6 @@ def litellm_telemetry(data):
'uuid': uuid_value,
'data': data
}
print_verbose(f"payload: {payload}")
try:
# Make the POST request to localhost:3000
response = requests.post('https://litellm.berri.ai/logging', json=payload)