return llm_provider as part of the exception

Krrish Dholakia 2023-08-17 11:31:34 -07:00
parent 2cfd4dd871
commit b91c69ffde
6 changed files with 75 additions and 31 deletions

litellm/exceptions.py (new file, 43 lines)

@@ -0,0 +1,43 @@
## LiteLLM versions of the OpenAI Exception Types
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError

class AuthenticationError(AuthenticationError):
    def __init__(self, message, llm_provider):
        self.status_code = 401
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class InvalidRequestError(InvalidRequestError):
    def __init__(self, message, model, llm_provider):
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        super().__init__(self.message, f"{self.model}")  # Call the base class constructor with the parameters it needs

class RateLimitError(RateLimitError):
    def __init__(self, message, llm_provider):
        self.status_code = 429
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class ServiceUnavailableError(ServiceUnavailableError):
    def __init__(self, message, llm_provider):
        self.status_code = 500
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class OpenAIError(OpenAIError):
    def __init__(self, original_exception):
        self.status_code = original_exception.http_status
        super().__init__(http_body=original_exception.http_body,
                         http_status=original_exception.http_status,
                         json_body=original_exception.json_body,
                         headers=original_exception.headers,
                         code=original_exception.code)
        self.llm_provider = "openai"
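
A minimal usage sketch of the wrappers above (not part of this commit; the completion call, model name, and messages are assumptions for illustration): calling code catches the LiteLLM exception and reads the llm_provider attribute it now carries.

# Illustrative usage sketch (assumed call site, not from this commit).
from litellm import completion
from litellm.exceptions import AuthenticationError, RateLimitError

messages = [{"role": "user", "content": "Hey, how's it going?"}]
try:
    response = completion(model="claude-instant-1", messages=messages)
except AuthenticationError as e:
    # llm_provider identifies which backend rejected the request, e.g. "anthropic"
    print(f"auth error from {e.llm_provider} (status {e.status_code}): {e.message}")
except RateLimitError as e:
    print(f"rate limited by {e.llm_provider}")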

@@ -40,11 +40,11 @@ def test_context_window(model):
print(f"model: {model}")
response = completion(model=model, messages=messages, custom_llm_provider="azure", logger_fn=logging_fn)
print(f"response: {response}")
except InvalidRequestError:
print("InvalidRequestError")
except InvalidRequestError as e:
print(f"InvalidRequestError: {e.llm_provider}")
return
except OpenAIError:
print("OpenAIError")
except OpenAIError as e:
print(f"OpenAIError: {e.llm_provider}")
return
except Exception as e:
print("Uncaught Error in test_context_window")
@@ -81,7 +81,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {e}")
print(f"AuthenticationError Caught Exception - {e.llm_provider}")
except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
print(f"OpenAIError Caught Exception - {e}")
except Exception as e:
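
A hypothetical follow-up to the test pattern above: asserting on the new llm_provider attribute instead of only printing it. The test name, the pytest/monkeypatch usage, and the bad-key setup are assumptions for illustration, not code from this commit.

# Hypothetical test sketch (assumed, not from this commit): assert on llm_provider.
import pytest
from litellm import completion
from litellm.exceptions import AuthenticationError

def test_auth_error_carries_provider(monkeypatch):
    monkeypatch.setenv("ANTHROPIC_API_KEY", "bad-key")  # force an auth failure from Anthropic
    with pytest.raises(AuthenticationError) as excinfo:
        completion(model="claude-instant-1", messages=[{"role": "user", "content": "hi"}])
    assert excinfo.value.llm_provider == "anthropic"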

@@ -5,13 +5,13 @@ import litellm, openai
import random, uuid, requests
import datetime, time
import tiktoken
from importlib_metadata import DistributionNotFound, VersionConflict
encoding = tiktoken.get_encoding("cl100k_base")
from .integrations.helicone import HeliconeLogger
from .integrations.aispend import AISpendLogger
from .integrations.berrispend import BerriSpendLogger
from .integrations.supabase import Supabase
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError
from openai.error import OpenAIError as OriginalError
from .exceptions import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
sentry_sdk_instance = None
@@ -102,18 +102,14 @@ def install_and_import(package: str):
try:
# Import the module
module = importlib.import_module(package)
except (ModuleNotFoundError, ImportError):
except ImportError:
print_verbose(f"{package} is not installed. Installing...")
subprocess.call([sys.executable, "-m", "pip", "install", package])
globals()[package] = importlib.import_module(package)
except (DistributionNotFound, ImportError):
print_verbose(f"{package} is not installed. Installing...")
subprocess.call([sys.executable, "-m", "pip", "install", package])
globals()[package] = importlib.import_module(package)
except VersionConflict as vc:
print_verbose(f"Detected version conflict for {package}. Upgrading...")
subprocess.call([sys.executable, "-m", "pip", "install", "--upgrade", package])
globals()[package] = importlib.import_module(package)
# except VersionConflict as vc:
# print_verbose(f"Detected version conflict for {package}. Upgrading...")
# subprocess.call([sys.executable, "-m", "pip", "install", "--upgrade", package])
# globals()[package] = importlib.import_module(package)
finally:
if package not in globals().keys():
globals()[package] = importlib.import_module(package)
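
For readability, a condensed sketch of install_and_import after this consolidation. print() stands in for the module's print_verbose() so the snippet is self-contained; that substitution is an assumption of the sketch.

# Condensed sketch of the consolidated helper, following the shape shown in the diff.
# print() replaces print_verbose() here only so the snippet runs standalone.
import importlib
import subprocess
import sys

def install_and_import(package: str):
    try:
        importlib.import_module(package)
    except ImportError:
        print(f"{package} is not installed. Installing...")
        subprocess.call([sys.executable, "-m", "pip", "install", package])
        globals()[package] = importlib.import_module(package)
    finally:
        # regardless of which branch ran, make sure the module is registered in globals()
        if package not in globals().keys():
            globals()[package] = importlib.import_module(package)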
@@ -687,8 +683,13 @@ def exception_type(model, original_exception, custom_llm_provider):
global user_logger_fn
exception_mapping_worked = False
try:
if isinstance(original_exception, OpenAIError):
if isinstance(original_exception, OriginalError):
# Handle the OpenAIError
exception_mapping_worked = True
if custom_llm_provider == "azure":
original_exception.llm_provider = "azure"
else:
original_exception.llm_provider = "openai"
raise original_exception
elif model:
error_str = str(original_exception)
@@ -702,49 +703,49 @@ def exception_type(model, original_exception, custom_llm_provider):
print_verbose(f"status_code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(f"AnthropicException - {original_exception.message}")
raise AuthenticationError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(f"AnthropicException - {original_exception.message}", f"{model}")
raise InvalidRequestError(message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic")
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(f"AnthropicException - {original_exception.message}")
raise RateLimitError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
elif "Could not resolve authentication method. Expected either api_key or auth_token to be set." in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"AnthropicException - {error_str}")
raise AuthenticationError(message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic")
elif "replicate" in model:
if "Incorrect authentication token" in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"ReplicateException - {error_str}")
raise AuthenticationError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
elif exception_type == "ModelError":
exception_mapping_worked = True
raise InvalidRequestError(f"ReplicateException - {error_str}", f"{model}")
raise InvalidRequestError(message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate")
elif "Request was throttled" in error_str:
exception_mapping_worked = True
raise RateLimitError(f"ReplicateException - {error_str}")
raise RateLimitError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
elif exception_type == "ReplicateError": ## ReplicateError implies an error on Replicate server side, not user side
raise ServiceUnavailableError(f"ReplicateException - {error_str}")
raise ServiceUnavailableError(message=f"ReplicateException - {error_str}", llm_provider="replicate")
elif model == "command-nightly": #Cohere
if "invalid api token" in error_str or "No API key provided." in error_str:
exception_mapping_worked = True
raise AuthenticationError(f"CohereException - {error_str}")
raise AuthenticationError(message=f"CohereException - {original_exception.message}", llm_provider="cohere")
elif "too many tokens" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(f"CohereException - {error_str}", f"{model}")
raise InvalidRequestError(message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere")
elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(f"CohereException - {original_exception.message}")
raise RateLimitError(message=f"CohereException - {original_exception.message}", llm_provider="cohere")
elif custom_llm_provider == "huggingface":
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(f"HuggingfaceException - {original_exception.message}")
raise AuthenticationError(message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface")
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(f"HuggingfaceException - {original_exception.message}", f"{model}")
raise InvalidRequestError(message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface")
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(f"HuggingfaceException - {original_exception.message}")
raise RateLimitError(message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface")
raise original_exception # base case - return the original exception
else:
raise original_exception
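
Taken together, the branches above either re-raise the original OpenAI/Azure error with .llm_provider attached or raise a LiteLLM wrapper constructed with llm_provider=..., so a caller can recover the provider either way. A caller-side sketch (the model, messages, and getattr defaults are assumptions for illustration):

# Caller-side sketch (assumed usage): read the provider off any mapped error.
from litellm import completion

try:
    response = completion(model="command-nightly", messages=[{"role": "user", "content": "hi"}])
except Exception as e:
    provider = getattr(e, "llm_provider", "unknown")  # set by exception_type() above
    status = getattr(e, "status_code", None)
    print(f"provider={provider}, status={status}, error={e}")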