Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Merge pull request #5358 from BerriAI/litellm_fix_retry_after

Fix retry-after handling: cool down individual models based on their specific 'retry-after' header.

Commit: 415abc86c6
12 changed files with 754 additions and 202 deletions
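
The changes below thread each provider's 'retry-after' header through litellm's exception mapping so retry logic can cool down only the deployment that was actually rate limited. As a rough illustration of that idea (hypothetical names and data structure, not the router's actual implementation), a per-model cooldown could be tracked like this:

import time
from typing import Dict

# Hypothetical sketch: per-model "do not use before" timestamps.
_cooldown_until: Dict[str, float] = {}

def cooldown_model(model: str, retry_after_seconds: int, default_seconds: float = 60.0) -> None:
    # Prefer the provider's own retry-after value when it is positive; otherwise use a default window.
    wait = retry_after_seconds if retry_after_seconds > 0 else default_seconds
    _cooldown_until[model] = time.time() + wait

def model_is_cooling_down(model: str) -> bool:
    return time.time() < _cooldown_until.get(model, 0.0)

# Only the rate-limited deployment is skipped; other deployments stay available.
cooldown_model("azure/gpt-4o-eu", retry_after_seconds=30)
assert model_is_cooling_down("azure/gpt-4o-eu")
assert not model_is_cooling_down("azure/gpt-4o-us")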
@@ -638,7 +638,10 @@ def client(original_function):
            if is_coroutine is True:
                pass
            else:
-               if isinstance(original_response, ModelResponse):
+               if (
+                   isinstance(original_response, ModelResponse)
+                   and len(original_response.choices) > 0
+               ):
                    model_response: Optional[str] = original_response.choices[
                        0
                    ].message.content  # type: ignore
@@ -6382,6 +6385,7 @@ def _get_retry_after_from_exception_header(
            retry_after = int(retry_date - time.time())
        else:
            retry_after = -1

        return retry_after

    except Exception as e:
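
For reference, a 'Retry-After' header can carry either an integer number of seconds or an HTTP-date; the hunk above covers the date form by subtracting the current time and falls back to -1 when nothing usable is present. A standalone sketch of that general pattern (not the exact body of _get_retry_after_from_exception_header) might look like:

import time
from email.utils import parsedate_to_datetime
from typing import Mapping, Optional

def parse_retry_after(headers: Mapping[str, str]) -> Optional[int]:
    # Return the number of seconds to wait, or None if the header is absent or unparseable.
    value = headers.get("retry-after")
    if value is None:
        return None
    try:
        return int(value)  # e.g. "Retry-After: 120"
    except ValueError:
        pass
    try:
        retry_date = parsedate_to_datetime(value).timestamp()  # HTTP-date form
        return max(0, int(retry_date - time.time()))
    except (TypeError, ValueError):
        return None

print(parse_retry_after({"retry-after": "30"}))                             # 30
print(parse_retry_after({"retry-after": "Wed, 21 Oct 2026 07:28:00 GMT"}))  # seconds until that date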
@@ -6563,6 +6567,40 @@ def get_model_list():


####### EXCEPTION MAPPING ################
def _get_litellm_response_headers(
    original_exception: Exception,
) -> Optional[httpx.Headers]:
    """
    Extract and return the response headers from a mapped exception, if present.

    Used for accurate retry logic.
    """
    _response_headers: Optional[httpx.Headers] = None
    try:
        _response_headers = getattr(
            original_exception, "litellm_response_headers", None
        )
    except Exception:
        return None

    return _response_headers


def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]:
    """
    Extract and return the response headers from an exception, if present.

    Used for accurate retry logic.
    """
    _response_headers: Optional[httpx.Headers] = None
    try:
        _response_headers = getattr(original_exception, "headers", None)
    except Exception:
        return None

    return _response_headers


def exception_type(
    model,
    original_exception,
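
The two helpers above differ only in which attribute they read ("litellm_response_headers" vs. "headers"). Downstream callers can consume the attached headers just as defensively; a self-contained sketch (fake_llm_call is invented purely for illustration):

import httpx

def fake_llm_call():
    # Invented stand-in for a call that raises a mapped exception carrying provider headers.
    err = RuntimeError("rate limited")
    setattr(err, "litellm_response_headers", httpx.Headers({"retry-after": "30"}))
    raise err

try:
    fake_llm_call()
except Exception as exc:
    headers = getattr(exc, "litellm_response_headers", None)
    retry_after = headers.get("retry-after") if headers is not None else None
    print(f"retry-after reported by provider: {retry_after}")  # -> 30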
@@ -6587,6 +6625,10 @@ def exception_type(
            "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'."  # noqa
        )  # noqa
        print()  # noqa

+   litellm_response_headers = _get_response_headers(
+       original_exception=original_exception
+   )
    try:
        if model:
            if hasattr(original_exception, "message"):
@@ -6841,7 +6883,7 @@ def exception_type(
                    message=f"{exception_provider} - {message}",
                    model=model,
                    llm_provider=custom_llm_provider,
-                   response=original_exception.response,
+                   response=getattr(original_exception, "response", None),
                    litellm_debug_info=extra_information,
                )
            elif original_exception.status_code == 429:
@@ -6850,7 +6892,7 @@ def exception_type(
                    message=f"RateLimitError: {exception_provider} - {message}",
                    model=model,
                    llm_provider=custom_llm_provider,
-                   response=original_exception.response,
+                   response=getattr(original_exception, "response", None),
                    litellm_debug_info=extra_information,
                )
            elif original_exception.status_code == 503:
@@ -6859,7 +6901,7 @@ def exception_type(
                    message=f"ServiceUnavailableError: {exception_provider} - {message}",
                    model=model,
                    llm_provider=custom_llm_provider,
-                   response=original_exception.response,
+                   response=getattr(original_exception, "response", None),
                    litellm_debug_info=extra_information,
                )
            elif original_exception.status_code == 504:  # gateway timeout error
@@ -6877,7 +6919,7 @@ def exception_type(
                    message=f"APIError: {exception_provider} - {message}",
                    llm_provider=custom_llm_provider,
                    model=model,
-                   request=original_exception.request,
+                   request=getattr(original_exception, "request", None),
                    litellm_debug_info=extra_information,
                )
            else:
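
The recurring swap from original_exception.response to getattr(original_exception, "response", None) in the hunks above matters because not every caught exception exposes a .response (or .request) attribute; direct attribute access would raise an AttributeError inside the mapper and mask the original provider error. A minimal illustration:

class ProviderError(Exception):
    pass  # deliberately has no .response attribute

err = ProviderError("upstream failure")
print(getattr(err, "response", None))  # -> None, so exception mapping can continue
try:
    err.response  # direct access instead raises, hiding the original failure
except AttributeError as secondary:
    print(f"direct access raised: {secondary!r}")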
@@ -8165,7 +8207,7 @@ def exception_type(
                    model=model,
                    request=original_exception.request,
                )
-           elif custom_llm_provider == "azure":
+           elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
                message = get_error_message(error_obj=original_exception)
                if message is None:
                    if hasattr(original_exception, "message"):
@@ -8469,20 +8511,20 @@ def exception_type(
        threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
    # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
    if exception_mapping_worked:
+       setattr(e, "litellm_response_headers", litellm_response_headers)
        raise e
    else:
        for error_type in litellm.LITELLM_EXCEPTION_TYPES:
            if isinstance(e, error_type):
+               setattr(e, "litellm_response_headers", litellm_response_headers)
                raise e  # it's already mapped
-       raise APIConnectionError(
+       raised_exc = APIConnectionError(
            message="{}\n{}".format(original_exception, traceback.format_exc()),
            llm_provider="",
            model="",
            request=httpx.Request(
                method="POST",
                url="https://www.litellm.ai/",
            ),
        )
+       setattr(raised_exc, "litellm_response_headers", _response_headers)
+       raise raised_exc


######### Secret Manager ############################
@@ -10916,10 +10958,17 @@ class CustomStreamWrapper:


class TextCompletionStreamWrapper:
-   def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
+   def __init__(
+       self,
+       completion_stream,
+       model,
+       stream_options: Optional[dict] = None,
+       custom_llm_provider: Optional[str] = None,
+   ):
        self.completion_stream = completion_stream
        self.model = model
        self.stream_options = stream_options
+       self.custom_llm_provider = custom_llm_provider

    def __iter__(self):
        return self
@@ -10970,7 +11019,13 @@ class TextCompletionStreamWrapper:
        except StopIteration:
            raise StopIteration
        except Exception as e:
            print(f"got exception {e}")  # noqa
            raise exception_type(
                model=self.model,
                custom_llm_provider=self.custom_llm_provider or "",
                original_exception=e,
                completion_kwargs={},
                extra_kwargs={},
            )

    async def __anext__(self):
        try:
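
Storing custom_llm_provider on the wrapper lets the streaming error path above hand the real provider name to exception_type instead of an empty string. A hypothetical construction, assuming the class is importable from litellm.utils (the module this diff edits):

from litellm.utils import TextCompletionStreamWrapper  # assumed import path

wrapper = TextCompletionStreamWrapper(
    completion_stream=iter([]),           # placeholder stream for illustration
    model="gpt-3.5-turbo-instruct",       # example model name
    stream_options=None,
    custom_llm_provider="openai",         # now forwarded to exception_type on errors
)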