fixing exception mapping

This commit is contained in:
Krrish Dholakia 2023-08-05 09:52:01 -07:00
parent 9b0e9bf57c
commit 92a13958ce
8 changed files with 188 additions and 115 deletions

View file

@ -69,6 +69,6 @@ open_ai_embedding_models = [
'text-embedding-ada-002'
]
from .timeout import timeout
from .utils import client, logging, exception_type # Import all the symbols from main.py
from .utils import client, logging, exception_type, get_optional_params # Import all the symbols from main.py
from .main import * # Import all the symbols from main.py
from .integrations import *

View file

@ -6,7 +6,7 @@ from functools import partial
import dotenv
import traceback
import litellm
from litellm import client, logging, exception_type, timeout
from litellm import client, logging, exception_type, timeout, get_optional_params
import random
import asyncio
from tenacity import (
@ -20,51 +20,6 @@ dotenv.load_dotenv() # Loading env variables using dotenv
# TODO move this to utils.py
# TODO add translations
# TODO see if this worked - model_name == krrish
def get_optional_params(
# 12 optional params
functions = [],
function_call = "",
temperature = 1,
top_p = 1,
n = 1,
stream = False,
stop = None,
max_tokens = float('inf'),
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = {},
user = "",
deployment_id = None
):
optional_params = {}
if functions != []:
optional_params["functions"] = functions
if function_call != "":
optional_params["function_call"] = function_call
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
optional_params["top_p"] = top_p
if n != 1:
optional_params["n"] = n
if stream:
optional_params["stream"] = stream
if stop != None:
optional_params["stop"] = stop
if max_tokens != float('inf'):
optional_params["max_tokens"] = max_tokens
if presence_penalty != 0:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
if logit_bias != {}:
optional_params["logit_bias"] = logit_bias
if user != "":
optional_params["user"] = user
if deployment_id != None:
optional_params["deployment_id"] = user
return optional_params
####### COMPLETION ENDPOINTS ################
#############################################
async def acompletion(*args, **kwargs):
@ -285,12 +240,13 @@ def completion(
}
response = new_response
else:
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
args = locals()
raise ValueError(f"No valid completion model args passed in - {args}")
return response
except Exception as e:
# log the original exception
## LOGGING
logging(model=model, input=messages, azure=azure, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
## Map to OpenAI Exception
raise exception_type(model=model, original_exception=e)

View file

@ -8,6 +8,7 @@ from litellm import embedding, completion
from concurrent.futures import ThreadPoolExecutor
import pytest
litellm.failure_callback = ["sentry"]
# litellm.set_verbose = True
#### What this tests ####
# This tests exception mapping -> trigger an exception from an llm provider -> assert if output is of the expected type
@ -22,11 +23,16 @@ import pytest
# models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly"]
models = ["command-nightly"]
def logging_fn(model_call_dict):
print(f"model_call_dict: {model_call_dict['model']}")
if "model" in model_call_dict:
print(f"model_call_dict: {model_call_dict['model']}")
else:
print(f"model_call_dict: {model_call_dict}")
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
sample_text = "how does a court case get to the Supreme Court?" * 100000
sample_text = "how does a court case get to the Supreme Court?" * 5000
messages = [{"content": sample_text, "role": "user"}]
try:
azure = model == "chatgpt-test"
@ -41,44 +47,61 @@ def test_context_window(model):
return
except Exception as e:
print("Uncaught Error in test_context_window")
# print(f"Error Type: {type(e).__name__}")
print(f"Error Type: {type(e).__name__}")
print(f"Uncaught Exception - {e}")
pytest.fail(f"Error occurred: {e}")
return
test_context_window("command-nightly")
# # Test 2: InvalidAuth Errors
# def logger_fn(model_call_object: dict):
# print(f"model call details: {model_call_object}")
# @pytest.mark.parametrize("model", models)
# def invalid_auth(model): # set the model key to an invalid key, depending on the model
# messages = [{ "content": "Hello, how are you?","role": "user"}]
# try:
# azure = False
# if model == "gpt-3.5-turbo":
# os.environ["OPENAI_API_KEY"] = "bad-key"
# elif model == "chatgpt-test":
# os.environ["AZURE_API_KEY"] = "bad-key"
# azure = True
# elif model == "claude-instant-1":
# os.environ["ANTHROPIC_API_KEY"] = "bad-key"
# elif model == "command-nightly":
# os.environ["COHERE_API_KEY"] = "bad-key"
# elif model == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1":
# os.environ["REPLICATE_API_KEY"] = "bad-key"
# os.environ["REPLICATE_API_TOKEN"] = "bad-key"
# print(f"model: {model}")
# response = completion(model=model, messages=messages, azure=azure)
# print(f"response: {response}")
# except AuthenticationError as e:
# return
# except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# return
# except Exception as e:
# print(f"Uncaught Exception - {e}")
# pytest.fail(f"Error occurred: {e}")
# return
# Test 2: InvalidAuth Errors
@pytest.mark.parametrize("model", models)
def invalid_auth(model): # set the model key to an invalid key, depending on the model
messages = [{ "content": "Hello, how are you?","role": "user"}]
temporary_key = None
try:
azure = False
if model == "gpt-3.5-turbo":
temporary_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "bad-key"
elif model == "chatgpt-test":
temporary_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "bad-key"
azure = True
elif model == "claude-instant-1":
temporary_key = os.environ["ANTHROPIC_API_KEY"]
os.environ["ANTHROPIC_API_KEY"] = "bad-key"
elif model == "command-nightly":
temporary_key = os.environ["COHERE_API_KEY"]
os.environ["COHERE_API_KEY"] = "bad-key"
elif model == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1":
temporary_key = os.environ["REPLICATE_API_KEY"]
os.environ["REPLICATE_API_KEY"] = "bad-key"
print(f"model: {model}")
response = completion(model=model, messages=messages, azure=azure)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {e}")
except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
print(f"OpenAIError Caught Exception - {e}")
except Exception as e:
print(type(e))
print(e.__class__.__name__)
print(f"Uncaught Exception - {e}")
pytest.fail(f"Error occurred: {e}")
if temporary_key != None: # reset the key
if model == "gpt-3.5-turbo":
os.environ["OPENAI_API_KEY"] = temporary_key
elif model == "chatgpt-test":
os.environ["AZURE_API_KEY"] = temporary_key
azure = True
elif model == "claude-instant-1":
os.environ["ANTHROPIC_API_KEY"] = temporary_key
elif model == "command-nightly":
os.environ["COHERE_API_KEY"] = temporary_key
elif model == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1":
os.environ["REPLICATE_API_KEY"] = temporary_key
return
invalid_auth("command-nightly")
# # Test 3: Rate Limit Errors
# def test_model(model):
# try:

View file

@ -25,41 +25,44 @@ def print_verbose(print_statement):
####### LOGGING ###################
#Logging function -> log the exact model details + what's being sent | Non-Blocking
def logging(model, input, azure=False, additional_args={}, logger_fn=None, exception=None):
def logging(model=None, input=None, azure=False, additional_args={}, logger_fn=None, exception=None):
try:
model_call_details = {}
model_call_details["model"] = model
model_call_details["azure"] = azure
# log exception details
if model:
model_call_details["model"] = model
if azure:
model_call_details["azure"] = azure
if exception:
model_call_details["original_exception"] = exception
if litellm.telemetry:
safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
model_call_details["input"] = input
if input:
model_call_details["input"] = input
# log additional call details -> api key, etc.
if azure == True or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_embedding_models:
model_call_details["api_type"] = openai.api_type
model_call_details["api_base"] = openai.api_base
model_call_details["api_version"] = openai.api_version
model_call_details["api_key"] = openai.api_key
elif "replicate" in model:
model_call_details["api_key"] = os.environ.get("REPLICATE_API_TOKEN")
elif model in litellm.anthropic_models:
model_call_details["api_key"] = os.environ.get("ANTHROPIC_API_KEY")
elif model in litellm.cohere_models:
model_call_details["api_key"] = os.environ.get("COHERE_API_KEY")
model_call_details["additional_args"] = additional_args
if model:
if azure == True or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_embedding_models:
model_call_details["api_type"] = openai.api_type
model_call_details["api_base"] = openai.api_base
model_call_details["api_version"] = openai.api_version
model_call_details["api_key"] = openai.api_key
elif "replicate" in model:
model_call_details["api_key"] = os.environ.get("REPLICATE_API_TOKEN")
elif model in litellm.anthropic_models:
model_call_details["api_key"] = os.environ.get("ANTHROPIC_API_KEY")
elif model in litellm.cohere_models:
model_call_details["api_key"] = os.environ.get("COHERE_API_KEY")
model_call_details["additional_args"] = additional_args
## User Logging -> if you pass in a custom logging function or want to use sentry breadcrumbs
print_verbose(f"Basic model call details: {model_call_details}")
print_verbose(f"Logging Details: logger_fn - {logger_fn} | callable(logger_fn) - {callable(logger_fn)}")
if logger_fn and callable(logger_fn):
try:
logger_fn(model_call_details) # Expectation: any logger function passed in by the user should accept a dict object
except:
print_verbose(f"[Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
except:
traceback.print_exc()
except Exception as e:
print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
except Exception as e:
print(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}")
pass
####### CLIENT ###################
@ -67,7 +70,7 @@ def logging(model, input, azure=False, additional_args={}, logger_fn=None, excep
def client(original_function):
def function_setup(*args, **kwargs): #just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb
global callback_list, add_breadcrumb, user_logger_fn
if (len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0) and len(callback_list) == 0:
callback_list = list(set(litellm.success_callback + litellm.failure_callback))
set_callbacks(callback_list=callback_list,)
@ -77,13 +80,15 @@ def client(original_function):
message=f"Positional Args: {args}, Keyword Args: {kwargs}",
level="info",
)
if "logger_fn" in kwargs:
user_logger_fn = kwargs["logger_fn"]
except: # DO NOT BLOCK running the function because of this
print_verbose(f"[Non-Blocking] {traceback.format_exc()}")
pass
def wrapper(*args, **kwargs):
try:
function_setup(args, kwargs)
function_setup(*args, **kwargs)
## MODEL CALL
start_time = datetime.datetime.now()
result = original_function(*args, **kwargs)
@ -100,6 +105,51 @@ def client(original_function):
return wrapper
####### HELPER FUNCTIONS ################
def get_optional_params(
# 12 optional params
functions = [],
function_call = "",
temperature = 1,
top_p = 1,
n = 1,
stream = False,
stop = None,
max_tokens = float('inf'),
presence_penalty = 0,
frequency_penalty = 0,
logit_bias = {},
user = "",
deployment_id = None
):
optional_params = {}
if functions != []:
optional_params["functions"] = functions
if function_call != "":
optional_params["function_call"] = function_call
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
optional_params["top_p"] = top_p
if n != 1:
optional_params["n"] = n
if stream:
optional_params["stream"] = stream
if stop != None:
optional_params["stop"] = stop
if max_tokens != float('inf'):
optional_params["max_tokens"] = max_tokens
if presence_penalty != 0:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
if logit_bias != {}:
optional_params["logit_bias"] = logit_bias
if user != "":
optional_params["user"] = user
if deployment_id != None:
optional_params["deployment_id"] = user
return optional_params
def set_callbacks(callback_list):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger
try:
@ -150,8 +200,8 @@ def set_callbacks(callback_list):
def handle_failure(exception, traceback_exception, args, kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel
try:
print_verbose(f"handle_failure args: {args}")
print_verbose(f"handle_failure kwargs: {kwargs}")
# print_verbose(f"handle_failure args: {args}")
# print_verbose(f"handle_failure kwargs: {kwargs}")
success_handler = additional_details.pop("success_handler", None)
failure_handler = additional_details.pop("failure_handler", None)
@ -159,7 +209,8 @@ def handle_failure(exception, traceback_exception, args, kwargs):
additional_details["Event_Name"] = additional_details.pop("failed_event_name", "litellm.failed_query")
print_verbose(f"self.failure_callback: {litellm.failure_callback}")
print_verbose(f"additional_details: {additional_details}")
# print_verbose(f"additional_details: {additional_details}")
for callback in litellm.failure_callback:
try:
if callback == "slack":
@ -206,7 +257,9 @@ def handle_failure(exception, traceback_exception, args, kwargs):
}
failure_handler(call_details)
pass
except:
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, exception=e)
pass
def handle_success(args, kwargs, result, start_time, end_time):
@ -245,12 +298,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
if success_handler and callable(success_handler):
success_handler(args, kwargs)
pass
except:
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, exception=e)
print_verbose(f"Success Callback Error - {traceback.format_exc()}")
pass
def exception_type(model, original_exception):
global user_logger_fn
exception_mapping_worked = False
try:
if isinstance(original_exception, OpenAIError):
# Handle the OpenAIError
@ -265,32 +322,46 @@ def exception_type(model, original_exception):
if "status_code" in original_exception:
print_verbose(f"status_code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(f"AnthropicException - {original_exception.message}")
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(f"AnthropicException - {original_exception.message}", f"{model}")
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(f"AnthropicException - {original_exception.message}")
elif "replicate" in model:
if "Incorrect authentication token" in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"ReplicateException - {error_str}")
elif exception_type == "ModelError":
exception_mapping_worked = True
raise InvalidRequestError(f"ReplicateException - {error_str}", f"{model}")
elif "Request was throttled" in error_str:
exception_mapping_worked = True
raise RateLimitError(f"ReplicateException - {error_str}")
elif exception_type == "ReplicateError": ## ReplicateError implies an error on Replicate server side, not user side
raise ServiceUnavailableError(f"ReplicateException - {error_str}")
elif model == "command-nightly": #Cohere
if "invalid api token" in error_str or "No API key provided." in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"CohereException - {error_str}")
elif "too many tokens" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(f"CohereException - {error_str}", f"{model}")
elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(f"CohereException - {original_exception.message}")
raise original_exception # base case - return the original exception
else:
raise original_exception
except:
raise original_exception
except Exception as e:
## LOGGING
logging(logger_fn=user_logger_fn, additional_args={"original_exception": original_exception}, exception=e)
if exception_mapping_worked:
raise e
else: # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
raise original_exception
def safe_crash_reporting(model=None, exception=None, azure=None):
data = {
@ -323,7 +394,6 @@ def litellm_telemetry(data):
'uuid': uuid_value,
'data': data
}
print_verbose(f"payload: {payload}")
try:
# Make the POST request to localhost:3000
response = requests.post('https://litellm.berri.ai/logging', json=payload)

24
setup.py Normal file
View file

@ -0,0 +1,24 @@
from setuptools import setup, find_packages
setup(
name='litellm',
version='0.1.231',
description='Library to easily interface with LLM API providers',
author='BerriAI',
packages=[
'litellm'
],
package_data={
"litellm": ["integrations/*"], # Specify the directory path relative to your package
},
install_requires=[
'openai',
'cohere',
'pytest',
'anthropic',
'replicate',
'python-dotenv',
'openai[datalib]',
'tenacity'
],
)