Merge pull request #3962 from BerriAI/litellm_return_num_rets_max_exceptions

[Feat] return `num_retries` and `max_retries` in exceptions
commit fb49d036fb
Author: Ishaan Jaff
Date: 2024-06-01 17:48:38 -07:00 (committed via GitHub)
4 changed files with 355 additions and 9 deletions
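The change, in brief: every LiteLLM exception class gains optional `max_retries` and `num_retries` attributes, rendered by new `__str__`/`__repr__` methods, and the Router stamps them onto the exception once retries are exhausted. A minimal caller-side sketch of what this enables (not part of the diff; the model name, kwargs, and retry count are illustrative, and the attributes stay `None` unless the retry logic sets them):

```python
import litellm
from litellm.exceptions import RateLimitError

try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model
        messages=[{"role": "user", "content": "Hello"}],
        num_retries=2,  # ask litellm to retry transient failures
    )
except RateLimitError as e:
    # After this PR, retry metadata can ride along on the exception
    # itself instead of being appended to the message text.
    print(e.num_retries, e.max_retries)
```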

@@ -803,6 +803,7 @@ from .exceptions import (
     APIConnectionError,
     APIResponseValidationError,
     UnprocessableEntityError,
+    LITELLM_EXCEPTION_TYPES,
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server

@@ -22,16 +22,36 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         model,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 401
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # raise when invalid models passed, example gpt-8
 class NotFoundError(openai.NotFoundError):  # type: ignore
@@ -42,16 +62,36 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         llm_provider,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 404
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class BadRequestError(openai.BadRequestError):  # type: ignore
     def __init__(
@@ -61,6 +101,8 @@ class BadRequestError(openai.BadRequestError):  # type: ignore
         llm_provider,
         response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 400
         self.message = message
@@ -73,10 +115,28 @@ class BadRequestError(openai.BadRequestError):  # type: ignore
                 method="GET", url="https://litellm.ai"
             ),  # mock request object
         )
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class UnprocessableEntityError(openai.UnprocessableEntityError):  # type: ignore
     def __init__(
@@ -86,20 +146,46 @@ class UnprocessableEntityError(openai.UnprocessableEntityError):  # type: ignore
         llm_provider,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 422
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class Timeout(openai.APITimeoutError):  # type: ignore
     def __init__(
-        self, message, model, llm_provider, litellm_debug_info: Optional[str] = None
+        self,
+        message,
+        model,
+        llm_provider,
+        litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         super().__init__(
@@ -110,10 +196,25 @@ class Timeout(openai.APITimeoutError):  # type: ignore
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries

     # custom function to convert to str
     def __str__(self):
-        return str(self.message)
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message


 class PermissionDeniedError(openai.PermissionDeniedError):  # type:ignore
@@ -124,16 +225,36 @@ class PermissionDeniedError(openai.PermissionDeniedError):  # type:ignore
         model,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 403
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class RateLimitError(openai.RateLimitError):  # type: ignore
     def __init__(
@@ -143,16 +264,36 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         model,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 429
         self.message = message
         self.llm_provider = llm_provider
-        self.modle = model
+        self.model = model
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
 class ContextWindowExceededError(BadRequestError):  # type: ignore
@@ -176,6 +317,22 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
             response=response,
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
 class RejectedRequestError(BadRequestError):  # type: ignore
@@ -202,6 +359,22 @@ class RejectedRequestError(BadRequestError):  # type: ignore
             response=response,
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class ContentPolicyViolationError(BadRequestError):  # type: ignore
     # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
@@ -225,6 +398,22 @@ class ContentPolicyViolationError(BadRequestError):  # type: ignore
             response=response,
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
     def __init__(
@@ -234,16 +423,36 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         model,
         response: httpx.Response,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = 503
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
 class APIError(openai.APIError):  # type: ignore
@@ -255,14 +464,34 @@ class APIError(openai.APIError):  # type: ignore
         model,
         request: httpx.Request,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.status_code = status_code
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(self.message, request=request, body=None)  # type: ignore

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # raised if an invalid request (not get, delete, put, post) is made
 class APIConnectionError(openai.APIConnectionError):  # type: ignore
@@ -273,19 +502,45 @@ class APIConnectionError(openai.APIConnectionError):  # type: ignore
         model,
         request: httpx.Request,
         litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         self.status_code = 500
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(message=self.message, request=request)

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 # raised if an invalid request (not get, delete, put, post) is made
 class APIResponseValidationError(openai.APIResponseValidationError):  # type: ignore
     def __init__(
-        self, message, llm_provider, model, litellm_debug_info: Optional[str] = None
+        self,
+        message,
+        llm_provider,
+        model,
+        litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
     ):
         self.message = message
         self.llm_provider = llm_provider
@@ -293,8 +548,26 @@ class APIResponseValidationError(openai.APIResponseValidationError):  # type: ignore
         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         response = httpx.Response(status_code=500, request=request)
         self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
         super().__init__(response=response, body=None, message=message)

+    def __str__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+
+    def __repr__(self):
+        _message = self.message
+        if self.num_retries:
+            _message += f" LiteLLM Retried: {self.num_retries} times"
+        if self.max_retries:
+            _message += f", LiteLLM Max Retries: {self.max_retries}"
+        return _message
+

 class OpenAIError(openai.OpenAIError):  # type: ignore
     def __init__(self, original_exception):
@@ -309,6 +582,25 @@ class OpenAIError(openai.OpenAIError):  # type: ignore
         self.llm_provider = "openai"

+
+LITELLM_EXCEPTION_TYPES = [
+    AuthenticationError,
+    NotFoundError,
+    BadRequestError,
+    UnprocessableEntityError,
+    Timeout,
+    PermissionDeniedError,
+    RateLimitError,
+    ContextWindowExceededError,
+    RejectedRequestError,
+    ContentPolicyViolationError,
+    ServiceUnavailableError,
+    APIError,
+    APIConnectionError,
+    APIResponseValidationError,
+    OpenAIError,
+]
+

 class BudgetExceededError(Exception):
     def __init__(self, current_cost, max_budget):
         self.current_cost = current_cost
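Every exception class receives the same `__str__`/`__repr__` body, which appends the retry counts only when they are truthy (a count of 0 is omitted); the duplication is an obvious candidate for a shared mixin, but the diff keeps each class self-contained. The `LITELLM_EXCEPTION_TYPES` registry added at the bottom is what the Router diff below checks before setting the attributes. A quick illustrative check of the new formatting, constructing an exception by hand the way litellm's internals would (all argument values here are made up):

```python
import httpx
from litellm.exceptions import RateLimitError

# A mock 429 response; real ones come from the underlying provider call.
response = httpx.Response(
    status_code=429,
    request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
)
err = RateLimitError(
    message="rate limited",
    llm_provider="openai",
    model="gpt-3.5-turbo",
    response=response,
    num_retries=2,
    max_retries=2,
)
# __str__ now reports the retry counts alongside the original message:
print(str(err))
# rate limited LiteLLM Retried: 2 times, LiteLLM Max Retries: 2
```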

@@ -2059,6 +2059,8 @@ class Router:
             response = await original_function(*args, **kwargs)
             return response
         except Exception as e:
+            num_retries = None
+            current_attempt = None
             original_exception = e
             """
            Retry Logic
@@ -2128,11 +2130,10 @@ class Router:
                     )
                     await asyncio.sleep(_timeout)

-            try:
-                cooldown_deployments = await self._async_get_cooldown_deployments()
-                original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"
-            except:
-                pass
+            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
+                original_exception.max_retries = num_retries
+                original_exception.num_retries = current_attempt
+
             raise original_exception

     def should_retry_this_error(
@@ -2333,6 +2334,8 @@ class Router:
             response = original_function(*args, **kwargs)
             return response
         except Exception as e:
+            num_retries = None
+            current_attempt = None
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
             _healthy_deployments = self._get_healthy_deployments(
@@ -2383,6 +2386,11 @@ class Router:
                         healthy_deployments=_healthy_deployments,
                     )
                     time.sleep(_timeout)
+
+            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
+                original_exception.max_retries = num_retries
+                original_exception.num_retries = current_attempt
+
             raise original_exception

     ### HELPER FUNCTIONS
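In both Router retry paths, the old behavior — best-effort appending of retry counts and cooldown deployments to the message string inside a bare try/except — is replaced with structured attribute assignment, guarded by an exact-type check against the new `LITELLM_EXCEPTION_TYPES` registry (note that `type(...) in ...` does not match subclasses). A sketch of the resulting caller experience, assuming a Router configured with retries (the deployment config and model names are illustrative):

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    num_retries=2,  # retries the router will attempt before re-raising
)

async def main() -> None:
    try:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
        )
    except Exception as e:
        # For exception types in litellm.LITELLM_EXCEPTION_TYPES the router
        # now sets these attributes; for anything else they may be absent.
        print(getattr(e, "num_retries", None), getattr(e, "max_retries", None))

asyncio.run(main())
```

Finally, the test file below gains a skipped, local-only test that exercises `RetryPolicy(RateLimitErrorRetries=2)` against the batch-completion path.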

@@ -186,6 +186,51 @@ async def test_router_retry_policy(error_type):
     assert customHandler.previous_models == 3


+@pytest.mark.asyncio
+@pytest.mark.skip(
+    reason="This is a local only test, use this to confirm if retry policy works"
+)
+async def test_router_retry_policy_on_429_errors():
+    from litellm.router import RetryPolicy
+
+    retry_policy = RetryPolicy(
+        RateLimitErrorRetries=2,
+    )
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {
+                    "model": "vertex_ai/gemini-1.5-pro-001",
+                },
+            },
+        ],
+        retry_policy=retry_policy,
+        # set_verbose=True,
+        # debug_level="DEBUG",
+        allowed_fails=10,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    try:
+        # litellm.set_verbose = True
+        _one_message = [{"role": "user", "content": "Hello good morning"}]
+        messages = [_one_message] * 5
+        print("messages: ", messages)
+        responses = await router.abatch_completion(
+            models=["gpt-3.5-turbo"],
+            messages=messages,
+        )
+        print("responses: ", responses)
+    except Exception as e:
+        print("got an exception", e)
+        pass
+    await asyncio.sleep(0.05)  # give callbacks a moment to fire
+    print("customHandler.previous_models: ", customHandler.previous_models)
+
+
 @pytest.mark.parametrize("model_group", ["gpt-3.5-turbo", "bad-model"])
 @pytest.mark.asyncio
 async def test_dynamic_router_retry_policy(model_group):