Merge pull request #3457 from BerriAI/litellm_return_num_retries_exceptions

[Feat] return num_retries in litellm.Router exceptions
Ishaan Jaff 2024-05-04 20:41:54 -07:00 committed by GitHub
commit ba065653ca
3 changed files with 62 additions and 63 deletions


@@ -1544,6 +1544,10 @@ class Router:
num_retries=num_retries,
)
await asyncio.sleep(_timeout)
+ try:
+     original_exception.message += f"\nNumber Retries = {current_attempt}"
+ except:
+     pass
raise original_exception
def function_with_fallbacks(self, *args, **kwargs):
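For orientation, the pattern this hunk introduces: once the retry budget is exhausted, the final exception is annotated with the attempt count before being re-raised. Below is a minimal, self-contained sketch of that pattern, not the actual Router internals; call_with_retries and coro_factory are illustrative names, and the sketch uses `except Exception` where the commit uses a bare `except`.

import asyncio

async def call_with_retries(coro_factory, num_retries: int = 2):
    # Retry the call; remember the last exception if every attempt fails.
    original_exception = None
    for current_attempt in range(num_retries + 1):
        try:
            return await coro_factory()
        except Exception as e:
            original_exception = e
            await asyncio.sleep(0.1)  # stand-in for the router's computed backoff
    # As in the hunk above: annotate defensively, since not every exception
    # type exposes a mutable `.message` attribute.
    try:
        original_exception.message += f"\nNumber Retries = {current_attempt}"
    except Exception:
        pass
    raise original_exception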


@@ -82,7 +82,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages
# - error request, falling back notice, success notice
expected_logs = [
- "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
+ "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
"litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",


@@ -7886,21 +7886,46 @@ def exception_type(
exception_type = type(original_exception).__name__
else:
exception_type = ""
- _api_base = ""
- try:
-     _api_base = litellm.get_api_base(
-         model=model, optional_params=extra_kwargs
-     )
- except:
-     _api_base = ""
- error_str += f" \n model: {model} \n api_base: {_api_base} \n"
- error_str += str(completion_kwargs)
+ ################################################################################
+ # Common Extra information needed for all providers
+ # We pass num retries, api_base, vertex_deployment etc to the exception here
+ ################################################################################
+ _api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs)
+ messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
+ _vertex_project = extra_kwargs.get("vertex_project")
+ _vertex_location = extra_kwargs.get("vertex_location")
+ _metadata = extra_kwargs.get("metadata", {}) or {}
+ _model_group = _metadata.get("model_group")
+ _deployment = _metadata.get("deployment")
+ extra_information = f"\nModel: {model}"
+ if _api_base:
+     extra_information += f"\nAPI Base: {_api_base}"
+ if messages and len(messages) > 0:
+     extra_information += f"\nMessages: {messages}"
+ if _model_group is not None:
+     extra_information += f"\nmodel_group: {_model_group}\n"
+ if _deployment is not None:
+     extra_information += f"\ndeployment: {_deployment}\n"
+ if _vertex_project is not None:
+     extra_information += f"\nvertex_project: {_vertex_project}\n"
+ if _vertex_location is not None:
+     extra_information += f"\nvertex_location: {_vertex_location}\n"
+ ################################################################################
+ # End of Common Extra information Needed for all providers
+ ################################################################################
+ ################################################################################
+ #################### Start of Provider Exception mapping ####################
+ ################################################################################
if "Request Timeout Error" in error_str or "Request timed out" in error_str:
exception_mapping_worked = True
raise Timeout(
- message=f"APITimeoutError - Request timed out. \n model: {model} \n api_base: {_api_base} \n error_str: {error_str}",
+ message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}",
model=model,
llm_provider=custom_llm_provider,
)
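Concretely, once the common block is built, every mapped exception message becomes f"{exception_provider} - {message} {extra_information}". As an illustration (values copied from the updated test expectation above, not produced by running litellm here), the rendered message for a 401 from OpenAI looks like:

# Values taken from the updated expected_logs in the fallback test above.
extra_information = (
    "\nModel: gpt-3.5-turbo"
    "\nAPI Base: https://api.openai.com"
    "\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]"
    "\nmodel_group: gpt-3.5-turbo\n"
    "\ndeployment: gpt-3.5-turbo\n"
)
message = "Error code: 401 - ..."  # provider error body elided
print(f"OpenAIException - {message} {extra_information}")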
@@ -7935,7 +7960,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContextWindowExceededError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7946,7 +7971,7 @@ def exception_type(
):
exception_mapping_worked = True
raise NotFoundError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7957,7 +7982,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7968,7 +7993,7 @@ def exception_type(
):
exception_mapping_worked = True
raise BadRequestError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7979,7 +8004,7 @@ def exception_type(
):
exception_mapping_worked = True
raise AuthenticationError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7991,7 +8016,7 @@ def exception_type(
)
raise APIError(
status_code=500,
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=_request,
@@ -8001,7 +8026,7 @@ def exception_type(
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -8009,7 +8034,7 @@ def exception_type(
elif original_exception.status_code == 404:
exception_mapping_worked = True
raise NotFoundError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8017,14 +8042,14 @@ def exception_type(
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8032,7 +8057,7 @@ def exception_type(
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8040,7 +8065,7 @@ def exception_type(
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8048,7 +8073,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
)
@@ -8056,7 +8081,7 @@ def exception_type(
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=original_exception.request,
@@ -8064,7 +8089,7 @@ def exception_type(
else:
# if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
raise APIConnectionError(
- message=f"{exception_provider} - {message}",
+ message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=httpx.Request(
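Each branch above follows one shape: choose the exception class from the provider's status code and append extra_information to the message. A condensed sketch of that dispatch, summarizing the branches shown (plain strings stand in for litellm's exception classes; this is not litellm code):

STATUS_CODE_TO_EXCEPTION = {
    401: "AuthenticationError",
    404: "NotFoundError",
    408: "Timeout",
    422: "BadRequestError",
    429: "RateLimitError",
    503: "ServiceUnavailableError",
    504: "Timeout",  # gateway timeout
}

def map_provider_error(exception_provider, message, extra_information, status_code=None):
    # No status code means a connection-level failure (APIConnectionError);
    # any unmapped code falls through to a generic APIError.
    if status_code is None:
        exc_name = "APIConnectionError"
    else:
        exc_name = STATUS_CODE_TO_EXCEPTION.get(status_code, "APIError")
    return exc_name, f"{exception_provider} - {message} {extra_information}"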
@@ -8371,33 +8396,13 @@ def exception_type(
response=original_exception.response,
)
elif custom_llm_provider == "vertex_ai":
- if completion_kwargs is not None:
-     # add model, deployment and model_group to the exception message
-     _model = completion_kwargs.get("model")
-     error_str += f"\nmodel: {_model}\n"
- if extra_kwargs is not None:
-     _vertex_project = extra_kwargs.get("vertex_project")
-     _vertex_location = extra_kwargs.get("vertex_location")
-     _metadata = extra_kwargs.get("metadata", {}) or {}
-     _model_group = _metadata.get("model_group")
-     _deployment = _metadata.get("deployment")
-     if _model_group is not None:
-         error_str += f"model_group: {_model_group}\n"
-     if _deployment is not None:
-         error_str += f"deployment: {_deployment}\n"
-     if _vertex_project is not None:
-         error_str += f"vertex_project: {_vertex_project}\n"
-     if _vertex_location is not None:
-         error_str += f"vertex_location: {_vertex_location}\n"
if (
"Vertex AI API has not been used in project" in error_str
or "Unable to find your project" in error_str
):
exception_mapping_worked = True
raise BadRequestError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
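The block deleted here is exactly what the new common section at the top of exception_type now computes once for every provider. Expressed as a hypothetical standalone helper (litellm does not define a function with this name), the shared logic is:

def build_extra_information(model, api_base=None, messages=None, model_group=None,
                            deployment=None, vertex_project=None, vertex_location=None):
    # Mirrors the common block added at the start of exception_type().
    info = f"\nModel: {model}"
    if api_base:
        info += f"\nAPI Base: {api_base}"
    if messages:
        info += f"\nMessages: {messages}"
    if model_group is not None:
        info += f"\nmodel_group: {model_group}\n"
    if deployment is not None:
        info += f"\ndeployment: {deployment}\n"
    if vertex_project is not None:
        info += f"\nvertex_project: {vertex_project}\n"
    if vertex_location is not None:
        info += f"\nvertex_location: {vertex_location}\n"
    return info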
@@ -8408,7 +8413,7 @@ def exception_type(
):
exception_mapping_worked = True
raise APIError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
status_code=500,
model=model,
llm_provider="vertex_ai",
@@ -8417,7 +8422,7 @@ def exception_type(
elif "403" in error_str:
exception_mapping_worked = True
raise BadRequestError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
@@ -8425,7 +8430,7 @@ def exception_type(
elif "The response was blocked." in error_str:
exception_mapping_worked = True
raise UnprocessableEntityError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=httpx.Response(
@@ -8444,7 +8449,7 @@ def exception_type(
):
exception_mapping_worked = True
raise RateLimitError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=httpx.Response(
@@ -8459,7 +8464,7 @@ def exception_type(
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
@@ -8467,7 +8472,7 @@ def exception_type(
if original_exception.status_code == 500:
exception_mapping_worked = True
raise APIError(
- message=f"VertexAIException - {error_str}",
+ message=f"VertexAIException - {error_str} {extra_information}",
status_code=500,
model=model,
llm_provider="vertex_ai",
@@ -9061,16 +9066,6 @@ def exception_type(
request=original_exception.request,
)
elif custom_llm_provider == "azure":
- _api_base = litellm.get_api_base(
-     model=model, optional_params=extra_kwargs
- )
- messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
- extra_information = f"\nModel: {model}"
- if _api_base:
-     extra_information += f"\nAPI Base: {_api_base}"
- if messages and len(messages) > 0:
-     extra_information += f"\nMessages: {messages}"
if "Internal server error" in error_str:
exception_mapping_worked = True
raise APIError(
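Taken together with the Router change in the first hunk, a call that exhausts its retries should now surface both the common context block and the retry count in the raised exception. A hedged usage sketch (the deployment config and key are placeholders; the exact message layout varies by provider):

import asyncio
from litellm import Router

async def main():
    router = Router(
        model_list=[{
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad-key"},
        }],
        num_retries=2,
    )
    try:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
        )
    except Exception as e:
        # Expect the mapped message plus a trailing "Number Retries = 2" line.
        print(str(e))

asyncio.run(main())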