Merge pull request #5358 from BerriAI/litellm_fix_retry_after

Fix retry-after handling: cool down individual models based on their specific 'retry-after' header
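In effect, router-side retry logic can now read the provider's `retry-after` off the exception LiteLLM raises and cool down only the deployment that was rate limited, instead of applying a blanket cooldown. A minimal sketch of that pattern, assuming a delta-seconds header value (the cooldown registry and function names here are illustrative, not LiteLLM's router internals):

```python
import time
from typing import Dict, Optional

import httpx

# Hypothetical per-deployment cooldown registry.
cooldown_until: Dict[str, float] = {}

def cooldown_model(model_id: str, exc: Exception, default_s: int = 60) -> None:
    # Mapped exceptions now carry the provider's response headers
    # on a `litellm_response_headers` attribute (added in this PR).
    headers: Optional[httpx.Headers] = getattr(exc, "litellm_response_headers", None)
    retry_after = default_s
    if headers is not None and headers.get("retry-after", "").isdigit():
        retry_after = int(headers["retry-after"])
    cooldown_until[model_id] = time.time() + retry_after

def is_available(model_id: str) -> bool:
    return time.time() >= cooldown_until.get(model_id, 0.0)
```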
Krish Dholakia 2024-08-27 11:50:14 -07:00 committed by GitHub
commit 415abc86c6
12 changed files with 754 additions and 202 deletions


@@ -638,7 +638,10 @@ def client(original_function):
             if is_coroutine is True:
                 pass
             else:
-                if isinstance(original_response, ModelResponse):
+                if (
+                    isinstance(original_response, ModelResponse)
+                    and len(original_response.choices) > 0
+                ):
                     model_response: Optional[str] = original_response.choices[
                         0
                     ].message.content  # type: ignore
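The added length check guards against responses whose `choices` list is empty, which would otherwise raise `IndexError` on `choices[0]` before the result ever reached the caller. Roughly (assuming `ModelResponse` accepts an empty `choices` list):

```python
from litellm import ModelResponse

resp = ModelResponse(choices=[])

# Old check: isinstance alone -> resp.choices[0] raises IndexError.
# New check: also require a non-empty choices list.
if isinstance(resp, ModelResponse) and len(resp.choices) > 0:
    content = resp.choices[0].message.content
else:
    content = None
```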
@@ -6382,6 +6385,7 @@ def _get_retry_after_from_exception_header(
                 retry_after = int(retry_date - time.time())
             else:
                 retry_after = -1
 
+        return retry_after
     except Exception as e:
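For context, `Retry-After` can carry either delta-seconds (`"30"`) or an HTTP-date (RFC 7231), which is why the helper above ends with a `retry_date - time.time()` branch. A self-contained sketch of that dual parse using only the standard library (not LiteLLM's exact helper):

```python
import email.utils
import time

def parse_retry_after(value: str) -> int:
    """Return seconds to wait, or -1 if the header is unparseable."""
    try:
        return int(value)  # delta-seconds form, e.g. "30"
    except ValueError:
        pass
    parsed = email.utils.parsedate_tz(value)  # HTTP-date form
    if parsed is None:
        return -1
    retry_date = email.utils.mktime_tz(parsed)
    return int(retry_date - time.time())
```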
@@ -6563,6 +6567,40 @@ def get_model_list():
 ####### EXCEPTION MAPPING ################
 
 
+def _get_litellm_response_headers(
+    original_exception: Exception,
+) -> Optional[httpx.Headers]:
+    """
+    Extract and return the response headers from a mapped exception, if present.
+
+    Used for accurate retry logic.
+    """
+    _response_headers: Optional[httpx.Headers] = None
+    try:
+        _response_headers = getattr(
+            original_exception, "litellm_response_headers", None
+        )
+    except Exception:
+        return None
+
+    return _response_headers
+
+
+def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]:
+    """
+    Extract and return the response headers from an exception, if present.
+
+    Used for accurate retry logic.
+    """
+    _response_headers: Optional[httpx.Headers] = None
+    try:
+        _response_headers = getattr(original_exception, "headers", None)
+    except Exception:
+        return None
+
+    return _response_headers
+
+
 def exception_type(
     model,
     original_exception,
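With these helpers in place, callers can recover the raw response headers from whatever exception LiteLLM raises. A usage sketch (model and message are placeholders):

```python
import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
except Exception as e:
    headers = getattr(e, "litellm_response_headers", None)
    if headers is not None:
        print("retry-after:", headers.get("retry-after"))
```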
@@ -6587,6 +6625,10 @@ def exception_type(
             "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'."  # noqa
         )  # noqa
         print()  # noqa
+
+    litellm_response_headers = _get_response_headers(
+        original_exception=original_exception
+    )
     try:
         if model:
             if hasattr(original_exception, "message"):
@@ -6841,7 +6883,7 @@ def exception_type(
                         message=f"{exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 429:
@@ -6850,7 +6892,7 @@ def exception_type(
                         message=f"RateLimitError: {exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 503:
@@ -6859,7 +6901,7 @@ def exception_type(
                         message=f"ServiceUnavailableError: {exception_provider} - {message}",
                         model=model,
                         llm_provider=custom_llm_provider,
-                        response=original_exception.response,
+                        response=getattr(original_exception, "response", None),
                         litellm_debug_info=extra_information,
                     )
                 elif original_exception.status_code == 504:  # gateway timeout error
@@ -6877,7 +6919,7 @@ def exception_type(
                         message=f"APIError: {exception_provider} - {message}",
                         llm_provider=custom_llm_provider,
                         model=model,
-                        request=original_exception.request,
+                        request=getattr(original_exception, "request", None),
                         litellm_debug_info=extra_information,
                     )
                 else:
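The four hunks above all make the same defensive change: `getattr(..., None)` instead of direct attribute access, so a provider exception that lacks `.response` or `.request` degrades to `None` rather than raising a secondary `AttributeError` inside the exception mapper:

```python
class ProviderError(Exception):
    """Stand-in for a provider exception with no .response attribute."""

e = ProviderError("boom")

# e.response                     -> AttributeError, masking the original error
# getattr(e, "response", None)   -> None, so mapping continues
response = getattr(e, "response", None)
assert response is None
```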
@@ -8165,7 +8207,7 @@ def exception_type(
                         model=model,
                         request=original_exception.request,
                     )
-        elif custom_llm_provider == "azure":
+        elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
             message = get_error_message(error_obj=original_exception)
             if message is None:
                 if hasattr(original_exception, "message"):
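This extends Azure's exception mapping to `azure_text`, the provider tag LiteLLM uses for Azure text-completion deployments. If I read the routing correctly, that tag comes from the model-string prefix (the deployment name below is a placeholder):

```python
import litellm

# "azure_text/<deployment>" should resolve custom_llm_provider to "azure_text"
litellm.text_completion(
    model="azure_text/my-instruct-deployment",
    prompt="Say hi",
)
```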
@@ -8469,20 +8511,20 @@ def exception_type(
         threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
         # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
         if exception_mapping_worked:
+            setattr(e, "litellm_response_headers", litellm_response_headers)
             raise e
         else:
             for error_type in litellm.LITELLM_EXCEPTION_TYPES:
                 if isinstance(e, error_type):
+                    setattr(e, "litellm_response_headers", litellm_response_headers)
                     raise e  # it's already mapped
-            raise APIConnectionError(
+            raised_exc = APIConnectionError(
                 message="{}\n{}".format(original_exception, traceback.format_exc()),
                 llm_provider="",
                 model="",
                 request=httpx.Request(
                     method="POST",
                     url="https://www.litellm.ai/",
                 ),
             )
+            setattr(raised_exc, "litellm_response_headers", _response_headers)
+            raise raised_exc
 
 
 ######### Secret Manager ############################
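All three raise paths now follow the same shape: attach the captured headers to the outgoing exception, then raise, so downstream retry logic sees the headers no matter which branch produced the error. In miniature:

```python
from typing import Optional

import httpx

def raise_with_headers(exc: Exception, headers: Optional[httpx.Headers]) -> None:
    # Attach metadata without subclassing every exception type.
    setattr(exc, "litellm_response_headers", headers)
    raise exc
```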
@@ -10916,10 +10958,17 @@ class CustomStreamWrapper:
 
 
 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
+    def __init__(
+        self,
+        completion_stream,
+        model,
+        stream_options: Optional[dict] = None,
+        custom_llm_provider: Optional[str] = None,
+    ):
         self.completion_stream = completion_stream
         self.model = model
         self.stream_options = stream_options
+        self.custom_llm_provider = custom_llm_provider
 
     def __iter__(self):
         return self
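Constructing the wrapper now threads the provider tag through so stream errors can be mapped (see the next hunk). A hedged example, where `raw_stream` stands in for the provider's chunk iterator:

```python
stream = TextCompletionStreamWrapper(
    completion_stream=raw_stream,  # placeholder iterator of raw chunks
    model="gpt-3.5-turbo-instruct",
    custom_llm_provider="openai",
)
for chunk in stream:
    print(chunk)
```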
@@ -10970,7 +11019,13 @@ class TextCompletionStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(f"got exception {e}")  # noqa
+            raise exception_type(
+                model=self.model,
+                custom_llm_provider=self.custom_llm_provider or "",
+                original_exception=e,
+                completion_kwargs={},
+                extra_kwargs={},
+            )
 
     async def __anext__(self):
         try:
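End to end, a mid-stream failure now surfaces as a typed LiteLLM exception instead of a stray `print`, so callers can branch on the error class and read the new headers attribute. A sketch, assuming a rate limit hits during streaming:

```python
import litellm

try:
    for chunk in litellm.text_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Write a haiku",
        stream=True,
    ):
        print(chunk)
except litellm.RateLimitError as e:
    headers = getattr(e, "litellm_response_headers", None)
    print("rate limited; retry-after:", headers.get("retry-after") if headers else None)
```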