feat: return num retries in exceptions

This commit is contained in:
Ishaan Jaff 2024-05-04 18:50:38 -07:00
parent 87e165e413
commit 0f03e53348

View file

@@ -7886,21 +7886,51 @@ def exception_type(
exception_type = type(original_exception).__name__
else:
exception_type = ""
_api_base = ""
try:
_api_base = litellm.get_api_base(
model=model, optional_params=extra_kwargs
)
except:
_api_base = ""
error_str += f" \n model: {model} \n api_base: {_api_base} \n"
error_str += str(completion_kwargs)
################################################################################
# Common Extra information needed for all providers
# We pass num retries, api_base, vertex_deployment etc to the exception here
################################################################################
_api_base = litellm.get_api_base(model=model, optional_params=extra_kwargs)
messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
_previous_requests = extra_kwargs.get("previous_models", None)
_vertex_project = extra_kwargs.get("vertex_project")
_vertex_location = extra_kwargs.get("vertex_location")
_metadata = extra_kwargs.get("metadata", {}) or {}
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
num_retries = 0
if _previous_requests and isinstance(_previous_requests, list):
num_retries = len(_previous_requests)
extra_information = f"\nModel: {model}"
if _api_base:
extra_information += f"\nAPI Base: {_api_base}"
if messages and len(messages) > 0:
extra_information += f"\nMessages: {messages}"
extra_information += f"\nNum Retries: {num_retries}"
if _model_group is not None:
extra_information += f"\nmodel_group: {_model_group}\n"
if _deployment is not None:
extra_information += f"\ndeployment: {_deployment}\n"
if _vertex_project is not None:
extra_information += f"\nvertex_project: {_vertex_project}\n"
if _vertex_location is not None:
extra_information += f"\nvertex_location: {_vertex_location}\n"
################################################################################
# End of Common Extra information Needed for all providers
################################################################################
################################################################################
#################### Start of Provider Exception mapping ####################
################################################################################
if "Request Timeout Error" in error_str or "Request timed out" in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"APITimeoutError - Request timed out. \n model: {model} \n api_base: {_api_base} \n error_str: {error_str}",
message=f"APITimeoutError - Request timed out. {extra_information} \n error_str: {error_str}",
model=model,
llm_provider=custom_llm_provider,
)
@@ -7935,7 +7965,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7946,7 +7976,7 @@ def exception_type(
):
exception_mapping_worked = True
raise NotFoundError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7957,7 +7987,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7968,7 +7998,7 @@ def exception_type(
):
exception_mapping_worked = True
raise BadRequestError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7979,7 +8009,7 @@ def exception_type(
):
exception_mapping_worked = True
raise AuthenticationError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -7991,7 +8021,7 @@ def exception_type(
)
raise APIError(
status_code=500,
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=_request,
@@ -8001,7 +8031,7 @@ def exception_type(
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
@@ -8009,7 +8039,7 @@ def exception_type(
elif original_exception.status_code == 404:
exception_mapping_worked = True
raise NotFoundError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8017,14 +8047,14 @@ def exception_type(
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8032,7 +8062,7 @@ def exception_type(
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8040,7 +8070,7 @@ def exception_type(
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
@@ -8048,7 +8078,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
)
@@ -8056,7 +8086,7 @@ def exception_type(
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=original_exception.request,
@@ -8064,7 +8094,7 @@ def exception_type(
else:
# if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors
raise APIConnectionError(
message=f"{exception_provider} - {message}",
message=f"{exception_provider} - {message} {extra_information}",
llm_provider=custom_llm_provider,
model=model,
request=httpx.Request(
@@ -8371,33 +8401,13 @@ def exception_type(
response=original_exception.response,
)
elif custom_llm_provider == "vertex_ai":
if completion_kwargs is not None:
# add model, deployment and model_group to the exception message
_model = completion_kwargs.get("model")
error_str += f"\nmodel: {_model}\n"
if extra_kwargs is not None:
_vertex_project = extra_kwargs.get("vertex_project")
_vertex_location = extra_kwargs.get("vertex_location")
_metadata = extra_kwargs.get("metadata", {}) or {}
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
if _model_group is not None:
error_str += f"model_group: {_model_group}\n"
if _deployment is not None:
error_str += f"deployment: {_deployment}\n"
if _vertex_project is not None:
error_str += f"vertex_project: {_vertex_project}\n"
if _vertex_location is not None:
error_str += f"vertex_location: {_vertex_location}\n"
if (
"Vertex AI API has not been used in project" in error_str
or "Unable to find your project" in error_str
):
exception_mapping_worked = True
raise BadRequestError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
@@ -8408,7 +8418,7 @@ def exception_type(
):
exception_mapping_worked = True
raise APIError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
status_code=500,
model=model,
llm_provider="vertex_ai",
@@ -8417,7 +8427,7 @@ def exception_type(
elif "403" in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
@@ -8425,7 +8435,7 @@ def exception_type(
elif "The response was blocked." in error_str:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=httpx.Response(
@@ -8444,7 +8454,7 @@ def exception_type(
):
exception_mapping_worked = True
raise RateLimitError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=httpx.Response(
@@ -8459,7 +8469,7 @@ def exception_type(
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
model=model,
llm_provider="vertex_ai",
response=original_exception.response,
@@ -8467,7 +8477,7 @@ def exception_type(
if original_exception.status_code == 500:
exception_mapping_worked = True
raise APIError(
message=f"VertexAIException - {error_str}",
message=f"VertexAIException - {error_str} {extra_information}",
status_code=500,
model=model,
llm_provider="vertex_ai",
@@ -9061,16 +9071,6 @@ def exception_type(
request=original_exception.request,
)
elif custom_llm_provider == "azure":
_api_base = litellm.get_api_base(
model=model, optional_params=extra_kwargs
)
messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
extra_information = f"\nModel: {model}"
if _api_base:
extra_information += f"\nAPI Base: {_api_base}"
if messages and len(messages) > 0:
extra_information += f"\nMessages: {messages}"
if "Internal server error" in error_str:
exception_mapping_worked = True
raise APIError(