return llm_provider as part of the exception

Krrish Dholakia 2023-08-17 11:31:34 -07:00
parent 2cfd4dd871
commit b91c69ffde
6 changed files with 75 additions and 31 deletions

litellm/exceptions.py (new file, +43)

@@ -0,0 +1,43 @@
## LiteLLM versions of the OpenAI Exception Types
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError

class AuthenticationError(AuthenticationError):
    def __init__(self, message, llm_provider):
        self.status_code = 401
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class InvalidRequestError(InvalidRequestError):
    def __init__(self, message, model, llm_provider):
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        super().__init__(self.message, f"{self.model}")  # Call the base class constructor with the parameters it needs

class RateLimitError(RateLimitError):
    def __init__(self, message, llm_provider):
        self.status_code = 429
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class ServiceUnavailableError(ServiceUnavailableError):
    def __init__(self, message, llm_provider):
        self.status_code = 500
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class OpenAIError(OpenAIError):
    def __init__(self, original_exception):
        self.status_code = original_exception.http_status
        super().__init__(http_body=original_exception.http_body,
                         http_status=original_exception.http_status,
                         json_body=original_exception.json_body,
                         headers=original_exception.headers,
                         code=original_exception.code)
        self.llm_provider = "openai"
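
Each wrapper keeps the corresponding OpenAI exception type as its base class but attaches llm_provider and a fixed status_code, so callers can tell which backend raised the error without parsing the message. A minimal usage sketch (not part of the commit), assuming the litellm.completion entry point; the model name and handler logic are illustrative:

    # Sketch only: catch the LiteLLM exception types added in this commit and read
    # the new llm_provider / status_code attributes. The model name and handling
    # below are illustrative assumptions, not taken from the diff.
    import litellm
    from litellm.exceptions import AuthenticationError, InvalidRequestError, RateLimitError

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    try:
        response = litellm.completion(model="claude-instant-1", messages=messages)
    except AuthenticationError as e:
        # 401 - the key for this provider is missing or wrong
        print(f"auth failed: provider={e.llm_provider}, status={e.status_code}")
    except InvalidRequestError as e:
        # 400 - bad request, e.g. prompt too long for this model
        print(f"bad request: provider={e.llm_provider}, model={e.model}")
    except RateLimitError as e:
        # 429 - provider throttling; a caller might back off and retry here
        print(f"rate limited: provider={e.llm_provider}")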


@@ -40,11 +40,11 @@ def test_context_window(model):
         print(f"model: {model}")
         response = completion(model=model, messages=messages, custom_llm_provider="azure", logger_fn=logging_fn)
         print(f"response: {response}")
-    except InvalidRequestError:
-        print("InvalidRequestError")
+    except InvalidRequestError as e:
+        print(f"InvalidRequestError: {e.llm_provider}")
         return
-    except OpenAIError:
-        print("OpenAIError")
+    except OpenAIError as e:
+        print(f"OpenAIError: {e.llm_provider}")
         return
     except Exception as e:
         print("Uncaught Error in test_context_window")
@@ -81,7 +81,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
         response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
         print(f"response: {response}")
     except AuthenticationError as e:
-        print(f"AuthenticationError Caught Exception - {e}")
+        print(f"AuthenticationError Caught Exception - {e.llm_provider}")
     except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
         print(f"OpenAIError Caught Exception - {e}")
     except Exception as e:


@@ -5,13 +5,13 @@ import litellm, openai
 import random, uuid, requests
 import datetime, time
 import tiktoken
-from importlib_metadata import DistributionNotFound, VersionConflict
 encoding = tiktoken.get_encoding("cl100k_base")
 from .integrations.helicone import HeliconeLogger
 from .integrations.aispend import AISpendLogger
 from .integrations.berrispend import BerriSpendLogger
 from .integrations.supabase import Supabase
-from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError
+from openai.error import OpenAIError as OriginalError
+from .exceptions import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv
 sentry_sdk_instance = None
@@ -102,18 +102,14 @@ def install_and_import(package: str):
     try:
         # Import the module
        module = importlib.import_module(package)
-    except (ModuleNotFoundError, ImportError):
+    except ImportError:
         print_verbose(f"{package} is not installed. Installing...")
         subprocess.call([sys.executable, "-m", "pip", "install", package])
         globals()[package] = importlib.import_module(package)
-    except (DistributionNotFound, ImportError):
-        print_verbose(f"{package} is not installed. Installing...")
-        subprocess.call([sys.executable, "-m", "pip", "install", package])
-        globals()[package] = importlib.import_module(package)
-    except VersionConflict as vc:
-        print_verbose(f"Detected version conflict for {package}. Upgrading...")
-        subprocess.call([sys.executable, "-m", "pip", "install", "--upgrade", package])
-        globals()[package] = importlib.import_module(package)
+        # except VersionConflict as vc:
+        #     print_verbose(f"Detected version conflict for {package}. Upgrading...")
+        #     subprocess.call([sys.executable, "-m", "pip", "install", "--upgrade", package])
+        #     globals()[package] = importlib.import_module(package)
     finally:
         if package not in globals().keys():
             globals()[package] = importlib.import_module(package)
@@ -687,8 +683,13 @@ def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn
     exception_mapping_worked = False
     try:
-        if isinstance(original_exception, OpenAIError):
+        if isinstance(original_exception, OriginalError):
             # Handle the OpenAIError
+            exception_mapping_worked = True
+            if custom_llm_provider == "azure":
+                original_exception.llm_provider = "azure"
+            else:
+                original_exception.llm_provider = "openai"
             raise original_exception
         elif model:
             error_str = str(original_exception)
@@ -702,49 +703,49 @@ def exception_type(model, original_exception, custom_llm_provider):
                     print_verbose(f"status_code: {original_exception.status_code}")
                     if original_exception.status_code == 401:
                         exception_mapping_worked = True
-                        raise AuthenticationError(f"AnthropicException - {original_exception.message}")
+                        raise AuthenticationError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
                     elif original_exception.status_code == 400:
                         exception_mapping_worked = True
-                        raise InvalidRequestError(f"AnthropicException - {original_exception.message}", f"{model}")
+                        raise InvalidRequestError(message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic")
                     elif original_exception.status_code == 429:
                         exception_mapping_worked = True
-                        raise RateLimitError(f"AnthropicException - {original_exception.message}")
+                        raise RateLimitError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
                 elif "Could not resolve authentication method. Expected either api_key or auth_token to be set." in error_str:
                     exception_mapping_worked = True
-                    raise AuthenticationError(f"AnthropicException - {error_str}")
+                    raise AuthenticationError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
             elif "replicate" in model:
                 if "Incorrect authentication token" in error_str:
                     exception_mapping_worked = True
-                    raise AuthenticationError(f"ReplicateException - {error_str}")
+                    raise AuthenticationError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
                 elif exception_type == "ModelError":
                     exception_mapping_worked = True
-                    raise InvalidRequestError(f"ReplicateException - {error_str}", f"{model}")
+                    raise InvalidRequestError(message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate")
                 elif "Request was throttled" in error_str:
                     exception_mapping_worked = True
-                    raise RateLimitError(f"ReplicateException - {error_str}")
+                    raise RateLimitError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
                 elif exception_type == "ReplicateError":  ## ReplicateError implies an error on Replicate server side, not user side
-                    raise ServiceUnavailableError(f"ReplicateException - {error_str}")
+                    raise ServiceUnavailableError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
             elif model == "command-nightly":  #Cohere
                 if "invalid api token" in error_str or "No API key provided." in error_str:
                     exception_mapping_worked = True
-                    raise AuthenticationError(f"CohereException - {error_str}")
+                    raise AuthenticationError(message=f"CohereException - {original_exception.message}", llm_provider="cohere")
                 elif "too many tokens" in error_str:
                     exception_mapping_worked = True
-                    raise InvalidRequestError(f"CohereException - {error_str}", f"{model}")
+                    raise InvalidRequestError(message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere")
                 elif "CohereConnectionError" in exception_type:  # cohere seems to fire these errors when we load test it (1k+ messages / min)
                     exception_mapping_worked = True
-                    raise RateLimitError(f"CohereException - {original_exception.message}")
+                    raise RateLimitError(message=f"CohereException - {original_exception.message}", llm_provider="cohere")
             elif custom_llm_provider == "huggingface":
                 if hasattr(original_exception, "status_code"):
                     if original_exception.status_code == 401:
                         exception_mapping_worked = True
-                        raise AuthenticationError(f"HuggingfaceException - {original_exception.message}")
+                        raise AuthenticationError(message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface")
                     elif original_exception.status_code == 400:
                         exception_mapping_worked = True
-                        raise InvalidRequestError(f"HuggingfaceException - {original_exception.message}", f"{model}")
+                        raise InvalidRequestError(message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface")
                     elif original_exception.status_code == 429:
                         exception_mapping_worked = True
-                        raise RateLimitError(f"HuggingfaceException - {original_exception.message}")
+                        raise RateLimitError(message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface")
             raise original_exception  # base case - return the original exception
         else:
             raise original_exception
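
Because each LiteLLM exception subclasses the matching openai.error type, handlers already written against the OpenAI SDK keep catching these errors; only code that wants the provider needs to read the new attribute. A small illustrative check (not part of the commit), assuming the pre-1.0 openai package that still ships openai.error:

    # Illustrative only: the LiteLLM error stays an instance of the openai.error
    # type it subclasses, so existing provider-agnostic handlers still match.
    import openai.error
    from litellm.exceptions import AuthenticationError

    err = AuthenticationError(message="AnthropicException - invalid x-api-key", llm_provider="anthropic")
    assert isinstance(err, openai.error.AuthenticationError)  # old handlers still work
    print(err.llm_provider, err.status_code)  # -> anthropic 401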